diff --git a/examples/obj_viewer.py b/examples/obj_viewer.py new file mode 100644 index 0000000000..0c241d8b41 --- /dev/null +++ b/examples/obj_viewer.py @@ -0,0 +1,1419 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ctypes +import math +import os +import string +import sys +import time +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple + +flags = sys.getdlopenflags() +sys.setdlopenflags(flags | ctypes.RTLD_GLOBAL) + +import habitat.datasets.rearrange.samplers.receptacle as hab_receptacle +import magnum as mn +import numpy as np +from habitat.sims.habitat_simulator.sim_utilities import snap_down +from magnum import shaders, text +from magnum.platform.glfw import Application + +import habitat_sim +from habitat_sim import ReplayRenderer, ReplayRendererConfiguration, physics +from habitat_sim.logging import LoggingContext, logger +from habitat_sim.utils.common import d3_40_colors_rgb, quat_from_angle_axis +from habitat_sim.utils.settings import default_sim_settings, make_cfg + +# add tools directory so I can import things to try them in the viewer +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tools")) +print(sys.path) +import collision_shape_automation as csa + +gt_raycast_results = None +pr_raycast_results = None +obj_temp_handle = None +test_points = None + + +class HabitatSimInteractiveViewer(Application): + # the maximum number of chars displayable in the app window + # using the magnum text module. 
These chars are used to + # display the CPU/GPU usage data + MAX_DISPLAY_TEXT_CHARS = 256 + + # how much to displace window text relative to the center of the + # app window (e.g if you want the display text in the top left of + # the app window, you will displace the text + # window width * -TEXT_DELTA_FROM_CENTER in the x axis and + # window height * TEXT_DELTA_FROM_CENTER in the y axis, as the text + # position defaults to the middle of the app window) + TEXT_DELTA_FROM_CENTER = 0.49 + + # font size of the magnum in-window display text that displays + # CPU and GPU usage info + DISPLAY_FONT_SIZE = 16.0 + + def __init__( + self, + sim_settings: Dict[str, Any], + mm: Optional[habitat_sim.metadata.MetadataMediator] = None, + ) -> None: + self.sim_settings: Dict[str:Any] = sim_settings + self.mm = mm + + self.enable_batch_renderer: bool = self.sim_settings["enable_batch_renderer"] + self.num_env: int = ( + self.sim_settings["num_environments"] if self.enable_batch_renderer else 1 + ) + + # Compute environment camera resolution based on the number of environments to render in the window. + window_size: mn.Vector2 = ( + self.sim_settings["window_width"], + self.sim_settings["window_height"], + ) + + configuration = self.Configuration() + configuration.title = "Habitat Sim Interactive Viewer" + configuration.size = window_size + Application.__init__(self, configuration) + self.fps: float = 60.0 + + # Compute environment camera resolution based on the number of environments to render in the window. + grid_size: mn.Vector2i = ReplayRenderer.environment_grid_size(self.num_env) + camera_resolution: mn.Vector2 = mn.Vector2(self.framebuffer_size) / mn.Vector2( + grid_size + ) + self.sim_settings["width"] = camera_resolution[0] + self.sim_settings["height"] = camera_resolution[1] + + # draw Bullet debug line visualizations (e.g. 
collision meshes) + self.debug_bullet_draw = False + # draw active contact point debug line visualizations + self.contact_debug_draw = False + # cache most recently loaded URDF file for quick-reload + self.cached_urdf = "" + + # set up our movement map + key = Application.KeyEvent.Key + self.pressed = { + key.UP: False, + key.DOWN: False, + key.LEFT: False, + key.RIGHT: False, + key.A: False, + key.D: False, + key.S: False, + key.W: False, + key.X: False, + key.Z: False, + } + + # set up our movement key bindings map + key = Application.KeyEvent.Key + self.key_to_action = { + key.UP: "look_up", + key.DOWN: "look_down", + key.LEFT: "turn_left", + key.RIGHT: "turn_right", + key.A: "move_left", + key.D: "move_right", + key.S: "move_backward", + key.W: "move_forward", + key.X: "move_down", + key.Z: "move_up", + } + + # Load a TrueTypeFont plugin and open the font file + self.display_font = text.FontManager().load_and_instantiate("TrueTypeFont") + relative_path_to_font = "../data/fonts/ProggyClean.ttf" + self.display_font.open_file( + os.path.join(os.path.dirname(__file__), relative_path_to_font), + 13, + ) + + # Glyphs we need to render everything + self.glyph_cache = text.GlyphCache(mn.Vector2i(256)) + self.display_font.fill_glyph_cache( + self.glyph_cache, + string.ascii_lowercase + + string.ascii_uppercase + + string.digits + + ":-_+,.! 
%ยต", + ) + + # magnum text object that displays CPU/GPU usage data in the app window + self.window_text = text.Renderer2D( + self.display_font, + self.glyph_cache, + HabitatSimInteractiveViewer.DISPLAY_FONT_SIZE, + text.Alignment.TOP_LEFT, + ) + self.window_text.reserve(HabitatSimInteractiveViewer.MAX_DISPLAY_TEXT_CHARS) + + # text object transform in window space is Projection matrix times Translation Matrix + # put text in top left of window + self.window_text_transform = mn.Matrix3.projection( + self.framebuffer_size + ) @ mn.Matrix3.translation( + mn.Vector2(self.framebuffer_size) + * mn.Vector2( + -HabitatSimInteractiveViewer.TEXT_DELTA_FROM_CENTER, + HabitatSimInteractiveViewer.TEXT_DELTA_FROM_CENTER, + ) + ) + self.shader = shaders.VectorGL2D() + + # make magnum text background transparent + mn.gl.Renderer.enable(mn.gl.Renderer.Feature.BLENDING) + mn.gl.Renderer.set_blend_function( + mn.gl.Renderer.BlendFunction.ONE, + mn.gl.Renderer.BlendFunction.ONE_MINUS_SOURCE_ALPHA, + ) + mn.gl.Renderer.set_blend_equation( + mn.gl.Renderer.BlendEquation.ADD, mn.gl.Renderer.BlendEquation.ADD + ) + + # variables that track app data and CPU/GPU usage + self.num_frames_to_track = 60 + + # Cycle mouse utilities + self.mouse_interaction = MouseMode.LOOK + self.mouse_grabber: Optional[MouseGrabber] = None + self.previous_mouse_point = None + + # toggle physics simulation on/off + self.simulating = False + self.sample_seed = 0 + self.collision_proxy_obj = None + self.mouse_cast_results = None + self.debug_draw_raycasts = True + + self.debug_draw_receptacles = True + self.object_receptacles = [] + + # toggle a single simulation step at the next opportunity if not + # simulating continuously. 
+ self.simulate_single_step = False + + # configure our simulator + self.cfg: Optional[habitat_sim.simulator.Configuration] = None + self.sim: Optional[habitat_sim.simulator.Simulator] = None + self.tiled_sims: list[habitat_sim.simulator.Simulator] = None + self.replay_renderer_cfg: Optional[ReplayRendererConfiguration] = None + self.replay_renderer: Optional[ReplayRenderer] = None + self.reconfigure_sim() + + if self.sim.pathfinder.is_loaded: + self.sim.pathfinder = habitat_sim.nav.PathFinder() + + # compute NavMesh if not already loaded by the scene. + # if ( + # not self.sim.pathfinder.is_loaded + # and self.cfg.sim_cfg.scene_id.lower() != "none" + # ): + # self.navmesh_config_and_recompute() + + self.time_since_last_simulation = 0.0 + LoggingContext.reinitialize_from_env() + logger.setLevel("INFO") + self.print_help_text() + + def draw_contact_debug(self): + """ + This method is called to render a debug line overlay displaying active contact points and normals. + Yellow lines show the contact distance along the normal and red lines show the contact normal at a fixed length. 
+ """ + yellow = mn.Color4.yellow() + red = mn.Color4.red() + cps = self.sim.get_physics_contact_points() + self.sim.get_debug_line_render().set_line_width(1.5) + camera_position = self.render_camera.render_camera.node.absolute_translation + # only showing active contacts + active_contacts = (x for x in cps if x.is_active) + for cp in active_contacts: + # red shows the contact distance + self.sim.get_debug_line_render().draw_transformed_line( + cp.position_on_b_in_ws, + cp.position_on_b_in_ws + + cp.contact_normal_on_b_in_ws * -cp.contact_distance, + red, + ) + # yellow shows the contact normal at a fixed length for visualization + self.sim.get_debug_line_render().draw_transformed_line( + cp.position_on_b_in_ws, + # + cp.contact_normal_on_b_in_ws * cp.contact_distance, + cp.position_on_b_in_ws + cp.contact_normal_on_b_in_ws * 0.1, + yellow, + ) + self.sim.get_debug_line_render().draw_circle( + translation=cp.position_on_b_in_ws, + radius=0.005, + color=yellow, + normal=camera_position - cp.position_on_b_in_ws, + ) + + def debug_draw(self): + """ + Additional draw commands to be called during draw_event. 
+ """ + if self.debug_bullet_draw: + render_cam = self.render_camera.render_camera + proj_mat = render_cam.projection_matrix.__matmul__(render_cam.camera_matrix) + self.sim.physics_debug_draw(proj_mat) + if self.contact_debug_draw: + self.draw_contact_debug() + + # mouse raycast circle + white = mn.Color4(mn.Vector3(1.0), 1.0) + if self.mouse_cast_results is not None and self.mouse_cast_results.has_hits(): + m_ray = self.mouse_cast_results.ray + self.sim.get_debug_line_render().draw_circle( + translation=m_ray.origin + + m_ray.direction + * self.mouse_cast_results.hits[0].ray_distance + * m_ray.direction.length(), + radius=0.005, + color=white, + normal=self.mouse_cast_results.hits[0].normal, + ) + + if gt_raycast_results is not None and self.debug_draw_raycasts: + scene_bb = self.sim.get_active_scene_graph().get_root_node().cumulative_bb + inflated_scene_bb = scene_bb.scaled(mn.Vector3(1.25)) + inflated_scene_bb = mn.Range3D.from_center( + scene_bb.center(), inflated_scene_bb.size() / 2.0 + ) + self.sim.get_debug_line_render().draw_box( + inflated_scene_bb.min, inflated_scene_bb.max, white + ) + if self.sim.get_rigid_object_manager().get_num_objects() == 0: + self.collision_proxy_obj = ( + self.sim.get_rigid_object_manager().add_object_by_template_handle( + obj_temp_handle + ) + ) + self.collision_proxy_obj.motion_type = ( + habitat_sim.physics.MotionType.KINEMATIC + ) + + csa.debug_draw_raycast_results( + self.sim, gt_raycast_results, pr_raycast_results, seed=self.sample_seed + ) + + # draw test points + for side in test_points: + for p in side: + self.sim.get_debug_line_render().draw_circle( + translation=p, + radius=0.005, + color=mn.Color4.magenta(), + ) + + if self.debug_draw_receptacles and self.collision_proxy_obj is not None: + # parse any receptacles defined for the object + if len(self.object_receptacles) == 0: + source_template_file = ( + self.collision_proxy_obj.creation_attributes.file_directory + ) + user_attr = 
self.collision_proxy_obj.user_attributes + self.object_receptacles = ( + hab_receptacle.parse_receptacles_from_user_config( + user_attr, + parent_object_handle=self.collision_proxy_obj.handle, + parent_template_directory=source_template_file, + ) + ) + # draw any receptacles for the object + for rix, receptacle in enumerate(self.object_receptacles): + c = d3_40_colors_rgb[rix] + rec_color = mn.Vector3(c[0], c[1], c[2]) / 256.0 + receptacle.debug_draw(self.sim, color=mn.Color4(rec_color)) + + def draw_event( + self, + simulation_call: Optional[Callable] = None, + global_call: Optional[Callable] = None, + active_agent_id_and_sensor_name: Tuple[int, str] = (0, "color_sensor"), + ) -> None: + """ + Calls continuously to re-render frames and swap the two frame buffers + at a fixed rate. + """ + agent_acts_per_sec = self.fps + + mn.gl.default_framebuffer.clear( + mn.gl.FramebufferClear.COLOR | mn.gl.FramebufferClear.DEPTH + ) + + # Agent actions should occur at a fixed rate per second + self.time_since_last_simulation += Timer.prev_frame_duration + num_agent_actions: int = self.time_since_last_simulation * agent_acts_per_sec + self.move_and_look(int(num_agent_actions)) + + # Occasionally a frame will pass quicker than 1/60 seconds + if self.time_since_last_simulation >= 1.0 / self.fps: + if self.simulating or self.simulate_single_step: + self.sim.step_world(1.0 / self.fps) + self.simulate_single_step = False + if simulation_call is not None: + simulation_call() + if global_call is not None: + global_call() + + # reset time_since_last_simulation, accounting for potential overflow + self.time_since_last_simulation = math.fmod( + self.time_since_last_simulation, 1.0 / self.fps + ) + + keys = active_agent_id_and_sensor_name + + if self.enable_batch_renderer: + self.render_batch() + else: + self.sim._Simulator__sensors[keys[0]][keys[1]].draw_observation() + agent = self.sim.get_agent(keys[0]) + self.render_camera = agent.scene_node.node_sensor_suite.get(keys[1]) + 
self.debug_draw() + self.render_camera.render_target.blit_rgba_to_default() + + # draw CPU/GPU usage data and other info to the app window + mn.gl.default_framebuffer.bind() + self.draw_text(self.render_camera.specification()) + + self.swap_buffers() + Timer.next_frame() + self.redraw() + + def default_agent_config(self) -> habitat_sim.agent.AgentConfiguration: + """ + Set up our own agent and agent controls + """ + make_action_spec = habitat_sim.agent.ActionSpec + make_actuation_spec = habitat_sim.agent.ActuationSpec + MOVE, LOOK = 0.07, 1.5 + + # all of our possible actions' names + action_list = [ + "move_left", + "turn_left", + "move_right", + "turn_right", + "move_backward", + "look_up", + "move_forward", + "look_down", + "move_down", + "move_up", + ] + + action_space: Dict[str, habitat_sim.agent.ActionSpec] = {} + + # build our action space map + for action in action_list: + actuation_spec_amt = MOVE if "move" in action else LOOK + action_spec = make_action_spec( + action, make_actuation_spec(actuation_spec_amt) + ) + action_space[action] = action_spec + + sensor_spec: List[habitat_sim.sensor.SensorSpec] = self.cfg.agents[ + self.agent_id + ].sensor_specifications + + agent_config = habitat_sim.agent.AgentConfiguration( + height=1.5, + radius=0.1, + sensor_specifications=sensor_spec, + action_space=action_space, + body_type="cylinder", + ) + return agent_config + + def reconfigure_sim(self) -> None: + """ + Utilizes the current `self.sim_settings` to configure and set up a new + `habitat_sim.Simulator`, and then either starts a simulation instance, or replaces + the current simulator instance, reloading the most recently loaded scene + """ + # configure our sim_settings but then set the agent to our default + self.cfg = make_cfg(self.sim_settings) + self.cfg.metadata_mediator = mm + self.agent_id: int = self.sim_settings["default_agent"] + self.cfg.agents[self.agent_id] = self.default_agent_config() + + if self.enable_batch_renderer: + 
self.cfg.enable_batch_renderer = True + self.cfg.sim_cfg.create_renderer = False + self.cfg.sim_cfg.enable_gfx_replay_save = True + + if self.sim_settings["stage_requires_lighting"]: + logger.info("Setting synthetic lighting override for stage.") + self.cfg.sim_cfg.override_scene_light_defaults = True + self.cfg.sim_cfg.scene_light_setup = habitat_sim.gfx.DEFAULT_LIGHTING_KEY + + # create custom stage from object + if self.cfg.metadata_mediator is None: + self.cfg.metadata_mediator = habitat_sim.metadata.MetadataMediator() + self.cfg.metadata_mediator.active_dataset = self.sim_settings[ + "scene_dataset_config_file" + ] + if args.reorient_object: + obj_handle = ( + self.cfg.metadata_mediator.object_template_manager.get_template_handles( + args.target_object + )[0] + ) + fp_models_metadata_file = ( + "/home/alexclegg/Documents/dev/fphab/fpModels_metadata.csv" + ) + obj_orientations = csa.parse_object_orientations_from_metadata_csv( + fp_models_metadata_file + ) + csa.correct_object_orientations( + [obj_handle], obj_orientations, self.cfg.metadata_mediator + ) + + otm = self.cfg.metadata_mediator.object_template_manager + obj_template = otm.get_template_by_handle(obj_temp_handle) + obj_template.compute_COM_from_shape = False + obj_template.com = mn.Vector3(0) + otm.register_template(obj_template) + stm = self.cfg.metadata_mediator.stage_template_manager + stage_template_name = "obj_as_stage_template" + new_stage_template = stm.create_new_template(handle=stage_template_name) + new_stage_template.render_asset_handle = obj_template.render_asset_handle + new_stage_template.orient_up = obj_template.orient_up + new_stage_template.orient_front = obj_template.orient_front + + # margin must be 0 for snapping to work on overlapped gt/proxy + new_stage_template.margin = 0.0 + stm.register_template( + template=new_stage_template, specified_handle=stage_template_name + ) + self.cfg.sim_cfg.scene_id = stage_template_name + # visualize the object as its collision shape + 
obj_template.render_asset_handle = obj_template.collision_asset_handle + print(f"obj_template.render_asset_handle = {obj_template.render_asset_handle}") + print( + f"obj_template.collision_asset_handle = {obj_template.collision_asset_handle}" + ) + otm.register_template(obj_template) + + if self.sim is None: + self.tiled_sims = [] + for _i in range(self.num_env): + self.tiled_sims.append(habitat_sim.Simulator(self.cfg)) + self.sim = self.tiled_sims[0] + else: # edge case + for i in range(self.num_env): + if ( + self.tiled_sims[i].config.sim_cfg.scene_id + == self.cfg.sim_cfg.scene_id + ): + # we need to force a reset, so change the internal config scene name + self.tiled_sims[i].config.sim_cfg.scene_id = "NONE" + self.tiled_sims[i].reconfigure(self.cfg) + + # post reconfigure + self.default_agent = self.sim.get_agent(self.agent_id) + self.render_camera = self.default_agent.scene_node.node_sensor_suite.get( + "color_sensor" + ) + + # set sim_settings scene name as actual loaded scene + self.sim_settings["scene"] = self.sim.curr_scene_name + + # Initialize replay renderer + if self.enable_batch_renderer and self.replay_renderer is None: + self.replay_renderer_cfg = ReplayRendererConfiguration() + self.replay_renderer_cfg.num_environments = self.num_env + self.replay_renderer_cfg.standalone = ( + False # Context is owned by the GLFW window + ) + self.replay_renderer_cfg.sensor_specifications = self.cfg.agents[ + self.agent_id + ].sensor_specifications + self.replay_renderer_cfg.gpu_device_id = self.cfg.sim_cfg.gpu_device_id + self.replay_renderer_cfg.force_separate_semantic_scene_graph = False + self.replay_renderer_cfg.leave_context_with_background_renderer = False + self.replay_renderer = ReplayRenderer.create_batch_replay_renderer( + self.replay_renderer_cfg + ) + # Pre-load composite files + if sim_settings["composite_files"] is not None: + for composite_file in sim_settings["composite_files"]: + self.replay_renderer.preload_file(composite_file) + + otm = 
self.sim.metadata_mediator.object_template_manager + otm.load_configs("data/objects/ycb/configs/") + + Timer.start() + self.step = -1 + + def render_batch(self): + """ + This method updates the replay manager with the current state of environments and renders them. + """ + for i in range(self.num_env): + # Apply keyframe + keyframe = self.tiled_sims[i].gfx_replay_manager.extract_keyframe() + self.replay_renderer.set_environment_keyframe(i, keyframe) + # Copy sensor transforms + sensor_suite = self.tiled_sims[i]._sensors + for sensor_uuid, sensor in sensor_suite.items(): + transform = sensor._sensor_object.node.absolute_transformation() + self.replay_renderer.set_sensor_transform(i, sensor_uuid, transform) + # Render + self.replay_renderer.render(mn.gl.default_framebuffer) + + def move_and_look(self, repetitions: int) -> None: + """ + This method is called continuously with `self.draw_event` to monitor + any changes in the movement keys map `Dict[KeyEvent.key, Bool]`. + When a key in the map is set to `True` the corresponding action is taken. + """ + # avoids unnecessary updates to grabber's object position + if repetitions == 0: + return + + key = Application.KeyEvent.Key + agent = self.sim.agents[self.agent_id] + press: Dict[key.key, bool] = self.pressed + act: Dict[key.key, str] = self.key_to_action + + action_queue: List[str] = [act[k] for k, v in press.items() if v] + + for _ in range(int(repetitions)): + [agent.act(x) for x in action_queue] + + # update the grabber transform when our agent is moved + if self.mouse_grabber is not None: + # update location of grabbed object + self.update_grab_position(self.previous_mouse_point) + + def invert_gravity(self) -> None: + """ + Sets the gravity vector to the negative of it's previous value. This is + a good method for testing simulation functionality. 
+ """ + gravity: mn.Vector3 = self.sim.get_gravity() * -1 + self.sim.set_gravity(gravity) + + def key_press_event(self, event: Application.KeyEvent) -> None: + """ + Handles `Application.KeyEvent` on a key press by performing the corresponding functions. + If the key pressed is part of the movement keys map `Dict[KeyEvent.key, Bool]`, then the + key will be set to False for the next `self.move_and_look()` to update the current actions. + """ + key = event.key + pressed = Application.KeyEvent.Key + mod = Application.InputEvent.Modifier + + shift_pressed = bool(event.modifiers & mod.SHIFT) + alt_pressed = bool(event.modifiers & mod.ALT) + # warning: ctrl doesn't always pass through with other key-presses + + if key == pressed.ESC: + event.accepted = True + self.exit_event(Application.ExitEvent) + return + + elif key == pressed.H: + self.print_help_text() + + elif key == pressed.TAB: + # NOTE: (+ALT) - reconfigure without cycling scenes + if not alt_pressed: + # cycle the active scene from the set available in MetadataMediator + inc = -1 if shift_pressed else 1 + scene_ids = self.sim.metadata_mediator.get_scene_handles() + cur_scene_index = 0 + if self.sim_settings["scene"] not in scene_ids: + matching_scenes = [ + (ix, x) + for ix, x in enumerate(scene_ids) + if self.sim_settings["scene"] in x + ] + if not matching_scenes: + logger.warning( + f"The current scene, '{self.sim_settings['scene']}', is not in the list, starting cycle at index 0." 
+ ) + else: + cur_scene_index = matching_scenes[0][0] + else: + cur_scene_index = scene_ids.index(self.sim_settings["scene"]) + + next_scene_index = min( + max(cur_scene_index + inc, 0), len(scene_ids) - 1 + ) + self.sim_settings["scene"] = scene_ids[next_scene_index] + self.reconfigure_sim() + logger.info( + f"Reconfigured simulator for scene: {self.sim_settings['scene']}" + ) + + elif key == pressed.SPACE: + if not self.sim.config.sim_cfg.enable_physics: + logger.warn("Warning: physics was not enabled during setup") + else: + self.simulating = not self.simulating + logger.info(f"Command: physics simulating set to {self.simulating}") + + elif key == pressed.PERIOD: + if self.simulating: + logger.warn("Warning: physics simulation already running") + else: + self.simulate_single_step = True + logger.info("Command: physics step taken") + + elif key == pressed.COMMA: + self.debug_bullet_draw = not self.debug_bullet_draw + logger.info(f"Command: toggle Bullet debug draw: {self.debug_bullet_draw}") + + elif key == pressed.C: + if shift_pressed: + self.contact_debug_draw = not self.contact_debug_draw + logger.info( + f"Command: toggle contact debug draw: {self.contact_debug_draw}" + ) + else: + # perform a discrete collision detection pass and enable contact debug drawing to visualize the results + logger.info( + "Command: perform discrete collision detection and visualize active contacts." + ) + self.sim.perform_discrete_collision_detection() + self.contact_debug_draw = True + # TODO: add a nice log message with concise contact pair naming. 
+ + elif key == pressed.O: + # move the object in/out of the frame + if self.collision_proxy_obj is not None: + if self.collision_proxy_obj.translation == mn.Vector3(0): + self.collision_proxy_obj.translation = mn.Vector3(100) + else: + self.collision_proxy_obj.translation = mn.Vector3(0) + + elif key == pressed.T: + if alt_pressed: + self.debug_draw_raycasts = not self.debug_draw_raycasts + print(f"Toggled self.debug_draw_raycasts: {self.debug_draw_raycasts}") + elif shift_pressed: + self.sample_seed -= 1 + else: + self.sample_seed += 1 + + event.accepted = True + return + # load URDF + fixed_base = alt_pressed + urdf_file_path = "" + if shift_pressed and self.cached_urdf: + urdf_file_path = self.cached_urdf + else: + urdf_file_path = input("Load URDF: provide a URDF filepath:").strip() + + if not urdf_file_path: + logger.warn("Load URDF: no input provided. Aborting.") + elif not urdf_file_path.endswith((".URDF", ".urdf")): + logger.warn("Load URDF: input is not a URDF. Aborting.") + elif os.path.exists(urdf_file_path): + self.cached_urdf = urdf_file_path + aom = self.sim.get_articulated_object_manager() + ao = aom.add_articulated_object_from_urdf( + urdf_file_path, fixed_base, 1.0, 1.0, True + ) + ao.translation = ( + self.default_agent.scene_node.transformation.transform_point( + [0.0, 1.0, -1.5] + ) + ) + else: + logger.warn("Load URDF: input file not found. 
Aborting.") + + elif key == pressed.M: + self.cycle_mouse_mode() + logger.info(f"Command: mouse mode set to {self.mouse_interaction}") + + elif key == pressed.V: + self.invert_gravity() + logger.info("Command: gravity inverted") + + elif key == pressed.N: + # (default) - toggle navmesh visualization + # NOTE: (+ALT) - re-sample the agent position on the NavMesh + # NOTE: (+SHIFT) - re-compute the NavMesh + if alt_pressed: + logger.info("Command: resample agent state from navmesh") + if self.sim.pathfinder.is_loaded: + new_agent_state = habitat_sim.AgentState() + new_agent_state.position = ( + self.sim.pathfinder.get_random_navigable_point() + ) + new_agent_state.rotation = quat_from_angle_axis( + self.sim.random.uniform_float(0, 2.0 * np.pi), + np.array([0, 1, 0]), + ) + self.default_agent.set_state(new_agent_state) + else: + logger.warning( + "NavMesh is not initialized. Cannot sample new agent state." + ) + elif shift_pressed: + logger.info("Command: recompute navmesh") + self.navmesh_config_and_recompute() + else: + if self.sim.pathfinder.is_loaded: + self.sim.navmesh_visualization = not self.sim.navmesh_visualization + logger.info("Command: toggle navmesh") + else: + logger.warn("Warning: recompute navmesh first") + + # update map of moving/looking keys which are currently pressed + if key in self.pressed: + self.pressed[key] = True + event.accepted = True + self.redraw() + + def key_release_event(self, event: Application.KeyEvent) -> None: + """ + Handles `Application.KeyEvent` on a key release. When a key is released, if it + is part of the movement keys map `Dict[KeyEvent.key, Bool]`, then the key will + be set to False for the next `self.move_and_look()` to update the current actions. 
+ """ + key = event.key + + # update map of moving/looking keys which are currently pressed + if key in self.pressed: + self.pressed[key] = False + event.accepted = True + self.redraw() + + def mouse_move_event(self, event: Application.MouseMoveEvent) -> None: + """ + Handles `Application.MouseMoveEvent`. When in LOOK mode, enables the left + mouse button to steer the agent's facing direction. When in GRAB mode, + continues to update the grabber's object position with our agents position. + """ + + render_camera = self.render_camera.render_camera + ray = render_camera.unproject(self.get_mouse_position(event.position)) + self.mouse_cast_results = self.sim.cast_ray(ray=ray) + + button = Application.MouseMoveEvent.Buttons + # if interactive mode -> LOOK MODE + if event.buttons == button.LEFT and self.mouse_interaction == MouseMode.LOOK: + agent = self.sim.agents[self.agent_id] + delta = self.get_mouse_position(event.relative_position) / 2 + action = habitat_sim.agent.ObjectControls() + act_spec = habitat_sim.agent.ActuationSpec + + # left/right on agent scene node + action(agent.scene_node, "turn_right", act_spec(delta.x)) + + # up/down on cameras' scene nodes + action = habitat_sim.agent.ObjectControls() + sensors = list(self.default_agent.scene_node.subtree_sensors.values()) + [action(s.object, "look_down", act_spec(delta.y), False) for s in sensors] + + # if interactive mode is TRUE -> GRAB MODE + elif self.mouse_interaction == MouseMode.GRAB and self.mouse_grabber: + # update location of grabbed object + self.update_grab_position(self.get_mouse_position(event.position)) + + self.previous_mouse_point = self.get_mouse_position(event.position) + self.redraw() + event.accepted = True + + def construct_cylinder_object2( + self, cyl_radius: float = 0.04, cyl_height: float = 0.15 + ): + constructed_cyl_temp_name = "scaled_cyl_template" + otm = self.sim.metadata_mediator.object_template_manager + cyl_temp_handle = otm.get_synth_template_handles("cylinder")[0] + cyl_temp = 
otm.get_template_by_handle(cyl_temp_handle) + cyl_temp.scale = mn.Vector3(cyl_radius, cyl_height / 2.0, cyl_radius) + otm.register_template(cyl_temp, constructed_cyl_temp_name) + return constructed_cyl_temp_name + + def construct_cylinder_object( + self, cyl_radius: float = 0.04, cyl_height: float = 0.15 + ): + otm = self.sim.metadata_mediator.object_template_manager + cyl_temp_handle = otm.get_template_handles("chef")[0] + return cyl_temp_handle + + def mouse_press_event(self, event: Application.MouseEvent) -> None: + """ + Handles `Application.MouseEvent`. When in GRAB mode, click on + objects to drag their position. (right-click for fixed constraints) + """ + button = Application.MouseEvent.Button + physics_enabled = self.sim.get_physics_simulation_library() + + # if interactive mode is True -> GRAB MODE + if self.mouse_interaction == MouseMode.GRAB and physics_enabled: + render_camera = self.render_camera.render_camera + ray = render_camera.unproject(self.get_mouse_position(event.position)) + raycast_results = self.sim.cast_ray(ray=ray) + + if raycast_results.has_hits(): + hit_object, ao_link = -1, -1 + hit_info = raycast_results.hits[0] + + if hit_info.object_id >= 0: + # we hit an non-staged collision object + ro_mngr = self.sim.get_rigid_object_manager() + ao_mngr = self.sim.get_articulated_object_manager() + ao = ao_mngr.get_object_by_id(hit_info.object_id) + ro = ro_mngr.get_object_by_id(hit_info.object_id) + + if ro: + # if grabbed an object + hit_object = hit_info.object_id + object_pivot = ro.transformation.inverted().transform_point( + hit_info.point + ) + object_frame = ro.rotation.inverted() + elif ao: + # if grabbed the base link + hit_object = hit_info.object_id + object_pivot = ao.transformation.inverted().transform_point( + hit_info.point + ) + object_frame = ao.rotation.inverted() + else: + for ao_handle in ao_mngr.get_objects_by_handle_substring(): + ao = ao_mngr.get_object_by_handle(ao_handle) + link_to_obj_ids = ao.link_object_ids + + if 
hit_info.object_id in link_to_obj_ids: + # if we got a link + ao_link = link_to_obj_ids[hit_info.object_id] + object_pivot = ( + ao.get_link_scene_node(ao_link) + .transformation.inverted() + .transform_point(hit_info.point) + ) + object_frame = ao.get_link_scene_node( + ao_link + ).rotation.inverted() + hit_object = ao.object_id + break + # done checking for AO + + if hit_object >= 0: + node = self.default_agent.scene_node + constraint_settings = physics.RigidConstraintSettings() + + constraint_settings.object_id_a = hit_object + constraint_settings.link_id_a = ao_link + constraint_settings.pivot_a = object_pivot + constraint_settings.frame_a = ( + object_frame.to_matrix() @ node.rotation.to_matrix() + ) + constraint_settings.frame_b = node.rotation.to_matrix() + constraint_settings.pivot_b = hit_info.point + + # by default use a point 2 point constraint + if event.button == button.RIGHT: + constraint_settings.constraint_type = ( + physics.RigidConstraintType.Fixed + ) + + grip_depth = ( + hit_info.point - render_camera.node.absolute_translation + ).length() + + self.mouse_grabber = MouseGrabber( + constraint_settings, + grip_depth, + self.sim, + ) + else: + logger.warn("Oops, couldn't find the hit object. 
That's odd.") + # end if didn't hit the scene + # end has raycast hit + # end has physics enabled + elif ( + self.mouse_interaction == MouseMode.LOOK + and physics_enabled + and self.mouse_cast_results is not None + and self.mouse_cast_results.has_hits() + and event.button == button.RIGHT + ): + constructed_cyl_obj_handle = None + import random + + r = random.randint(0, 1) + if r == 0: + constructed_cyl_obj_handle = self.construct_cylinder_object() + else: + constructed_cyl_obj_handle = self.construct_cylinder_object2() + # try to place an object + if ( + mn.math.dot( + self.mouse_cast_results.hits[0].normal.normalized(), + mn.Vector3(0, 1, 0), + ) + > 0.5 + ): + rom = self.sim.get_rigid_object_manager() + cyl_test_obj = rom.add_object_by_template_handle( + constructed_cyl_obj_handle + ) + assert cyl_test_obj is not None + cyl_test_obj.translation = self.mouse_cast_results.hits[ + 0 + ].point + mn.Vector3(0, 0.04, 0) + success = snap_down( + self.sim, + cyl_test_obj, + support_obj_ids=[-1, self.collision_proxy_obj.object_id], + ) + print(success) + if not success: + rom.remove_object_by_handle(cyl_test_obj.handle) + + self.previous_mouse_point = self.get_mouse_position(event.position) + self.redraw() + event.accepted = True + + def mouse_scroll_event(self, event: Application.MouseScrollEvent) -> None: + """ + Handles `Application.MouseScrollEvent`. When in LOOK mode, enables camera + zooming (fine-grained zoom using shift) When in GRAB mode, adjusts the depth + of the grabber's object. 
(larger depth change rate using shift) + """ + scroll_mod_val = ( + event.offset.y + if abs(event.offset.y) > abs(event.offset.x) + else event.offset.x + ) + if not scroll_mod_val: + return + + # use shift to scale action response + shift_pressed = bool(event.modifiers & Application.InputEvent.Modifier.SHIFT) + alt_pressed = bool(event.modifiers & Application.InputEvent.Modifier.ALT) + ctrl_pressed = bool(event.modifiers & Application.InputEvent.Modifier.CTRL) + + # if interactive mode is False -> LOOK MODE + if self.mouse_interaction == MouseMode.LOOK: + # use shift for fine-grained zooming + mod_val = 1.01 if shift_pressed else 1.1 + mod = mod_val if scroll_mod_val > 0 else 1.0 / mod_val + cam = self.render_camera + cam.zoom(mod) + self.redraw() + + elif self.mouse_interaction == MouseMode.GRAB and self.mouse_grabber: + # adjust the depth + mod_val = 0.1 if shift_pressed else 0.01 + scroll_delta = scroll_mod_val * mod_val + if alt_pressed or ctrl_pressed: + # rotate the object's local constraint frame + agent_t = self.default_agent.scene_node.transformation_matrix() + # ALT - yaw + rotation_axis = agent_t.transform_vector(mn.Vector3(0, 1, 0)) + if alt_pressed and ctrl_pressed: + # ALT+CTRL - roll + rotation_axis = agent_t.transform_vector(mn.Vector3(0, 0, -1)) + elif ctrl_pressed: + # CTRL - pitch + rotation_axis = agent_t.transform_vector(mn.Vector3(1, 0, 0)) + self.mouse_grabber.rotate_local_frame_by_global_angle_axis( + rotation_axis, mn.Rad(scroll_delta) + ) + else: + # update location of grabbed object + self.mouse_grabber.grip_depth += scroll_delta + self.update_grab_position(self.get_mouse_position(event.position)) + self.redraw() + event.accepted = True + + def mouse_release_event(self, event: Application.MouseEvent) -> None: + """ + Release any existing constraints. 
+ """ + del self.mouse_grabber + self.mouse_grabber = None + event.accepted = True + + def update_grab_position(self, point: mn.Vector2i) -> None: + """ + Accepts a point derived from a mouse click event and updates the + transform of the mouse grabber. + """ + # check mouse grabber + if not self.mouse_grabber: + return + + render_camera = self.render_camera.render_camera + ray = render_camera.unproject(point) + + rotation: mn.Matrix3x3 = self.default_agent.scene_node.rotation.to_matrix() + translation: mn.Vector3 = ( + render_camera.node.absolute_translation + + ray.direction * self.mouse_grabber.grip_depth + ) + self.mouse_grabber.update_transform(mn.Matrix4.from_(rotation, translation)) + + def get_mouse_position(self, mouse_event_position: mn.Vector2i) -> mn.Vector2i: + """ + This function will get a screen-space mouse position appropriately + scaled based on framebuffer size and window size. Generally these would be + the same value, but on certain HiDPI displays (Retina displays) they may be + different. + """ + scaling = mn.Vector2i(self.framebuffer_size) / mn.Vector2i(self.window_size) + return mouse_event_position * scaling + + def cycle_mouse_mode(self) -> None: + """ + This method defines how to cycle through the mouse mode. + """ + if self.mouse_interaction == MouseMode.LOOK: + self.mouse_interaction = MouseMode.GRAB + elif self.mouse_interaction == MouseMode.GRAB: + self.mouse_interaction = MouseMode.LOOK + + def navmesh_config_and_recompute(self) -> None: + """ + This method is setup to be overridden in for setting config accessibility + in inherited classes. 
+ """ + self.navmesh_settings = habitat_sim.NavMeshSettings() + self.navmesh_settings.set_defaults() + self.navmesh_settings.agent_height = self.cfg.agents[self.agent_id].height + self.navmesh_settings.agent_radius = self.cfg.agents[self.agent_id].radius + self.navmesh_settings.include_static_objects = True + self.sim.recompute_navmesh( + self.sim.pathfinder, + self.navmesh_settings, + ) + + def exit_event(self, event: Application.ExitEvent): + """ + Overrides exit_event to properly close the Simulator before exiting the + application. + """ + for i in range(self.num_env): + self.tiled_sims[i].close(destroy=True) + event.accepted = True + exit(0) + + def draw_text(self, sensor_spec): + self.shader.bind_vector_texture(self.glyph_cache.texture) + self.shader.transformation_projection_matrix = self.window_text_transform + self.shader.color = [1.0, 1.0, 1.0] + + sensor_type_string = str(sensor_spec.sensor_type.name) + sensor_subtype_string = str(sensor_spec.sensor_subtype.name) + if self.mouse_interaction == MouseMode.LOOK: + mouse_mode_string = "LOOK" + elif self.mouse_interaction == MouseMode.GRAB: + mouse_mode_string = "GRAB" + self.window_text.render( + f""" +{self.fps} FPS +Sensor Type: {sensor_type_string} +Sensor Subtype: {sensor_subtype_string} +Mouse Interaction Mode: {mouse_mode_string} + """ + ) + self.shader.draw(self.window_text.mesh) + + def print_help_text(self) -> None: + """ + Print the Key Command help text. + """ + logger.info( + """ +===================================================== +Welcome to the Habitat-sim Python Viewer application! +===================================================== +Mouse Functions ('m' to toggle mode): +---------------- +In LOOK mode (default): + LEFT: + Click and drag to rotate the agent and look up/down. 
+ WHEEL: + Modify orthographic camera zoom/perspective camera FOV (+SHIFT for fine grained control) + +In GRAB mode (with 'enable-physics'): + LEFT: + Click and drag to pickup and move an object with a point-to-point constraint (e.g. ball joint). + RIGHT: + Click and drag to pickup and move an object with a fixed frame constraint. + WHEEL (with picked object): + default - Pull gripped object closer or push it away. + (+ALT) rotate object fixed constraint frame (yaw) + (+CTRL) rotate object fixed constraint frame (pitch) + (+ALT+CTRL) rotate object fixed constraint frame (roll) + (+SHIFT) amplify scroll magnitude + + +Key Commands: +------------- + esc: Exit the application. + 'h': Display this help message. + 'm': Cycle mouse interaction modes. + + Agent Controls: + 'wasd': Move the agent's body forward/backward and left/right. + 'zx': Move the agent's body up/down. + arrow keys: Turn the agent's body left/right and camera look up/down. + + Utilities: + 'r': Reset the simulator with the most recently loaded scene. + 'n': Show/hide NavMesh wireframe. + (+SHIFT) Recompute NavMesh with default settings. + (+ALT) Re-sample the agent(camera)'s position and orientation from the NavMesh. + ',': Render a Bullet collision shape debug wireframe overlay (white=active, green=sleeping, blue=wants sleeping, red=can't sleep). + 'c': Run a discrete collision detection pass and render a debug wireframe overlay showing active contact points and normals (yellow=fixed length normals, red=collision distances). + (+SHIFT) Toggle the contact point debug render overlay on/off. + + Object Interactions: + SPACE: Toggle physics simulation on/off. + '.': Take a single simulation step if not simulating continuously. + 'v': (physics) Invert gravity. 
class MouseMode(Enum):
    """
    Mouse interaction modes of the viewer application.
    """

    # click and drag rotates the agent / looks up and down, wheel zooms
    LOOK = 0
    # click and drag picks up and moves objects through a rigid constraint
    GRAB = 1
    # NOTE(review): cycle_mouse_mode only toggles between LOOK and GRAB, so this
    # member looks unused by the mouse handlers - confirm before relying on it.
    MOTION = 2
class Timer:
    """
    Timer class used to keep track of time between buffer swaps
    and guide the display frame rate.

    All state is stored on the class itself and every method is static, so
    there is a single shared timer for the whole process.
    """

    # wall-clock time (seconds, `time.time()`) at which the timer was started
    start_time = 0.0
    # wall-clock time at which the previous frame was recorded
    prev_frame_time = 0.0
    # duration in seconds between the two most recently recorded frames
    prev_frame_duration = 0.0
    # whether `next_frame` currently records anything
    running = False

    @staticmethod
    def start() -> None:
        """
        Starts timer and resets previous frame time to the start time.
        """
        Timer.running = True
        Timer.start_time = time.time()
        Timer.prev_frame_time = Timer.start_time
        Timer.prev_frame_duration = 0.0

    @staticmethod
    def stop() -> None:
        """
        Stops timer and erases any previous time data, resetting the timer.
        """
        Timer.running = False
        Timer.start_time = 0.0
        Timer.prev_frame_time = 0.0
        Timer.prev_frame_duration = 0.0

    @staticmethod
    def next_frame() -> None:
        """
        Records previous frame duration and updates the previous frame timestamp
        to the current time. If the timer is not currently running, perform nothing.
        """
        if not Timer.running:
            return
        # Sample the clock once: using the same timestamp for both the duration
        # and the new frame time means no time is lost between two clock reads,
        # so the recorded durations add up to the elapsed time exactly.
        now = time.time()
        Timer.prev_frame_duration = now - Timer.prev_frame_time
        Timer.prev_frame_time = now
If none is specified, the original scene files will be loaded from disk.", + ) + parser.add_argument( + "--width", + default=800, + type=int, + help="Horizontal resolution of the window.", + ) + parser.add_argument( + "--height", + default=600, + type=int, + help="Vertical resolution of the window.", + ) + + args = parser.parse_args() + + if args.num_environments < 1: + parser.error("num-environments must be a positive non-zero integer.") + if args.width < 1: + parser.error("width must be a positive non-zero integer.") + if args.height < 1: + parser.error("height must be a positive non-zero integer.") + + # Setting up sim_settings + sim_settings: Dict[str, Any] = default_sim_settings + # sim_settings["scene"] = args.target_object + sim_settings["scene"] = "NONE" + sim_settings["scene_dataset_config_file"] = args.dataset + sim_settings["enable_physics"] = not args.disable_physics + sim_settings["stage_requires_lighting"] = args.stage_requires_lighting + sim_settings["enable_batch_renderer"] = args.enable_batch_renderer + sim_settings["num_environments"] = args.num_environments + sim_settings["composite_files"] = args.composite_files + sim_settings["window_width"] = args.width + sim_settings["window_height"] = args.height + sim_settings["clear_color"] = mn.Color4.magenta() + + obj_name = args.target_object + + # load JSON once instead of repeating + mm = habitat_sim.metadata.MetadataMediator() + mm.active_dataset = sim_settings["scene_dataset_config_file"] + + obj_temp_handle = mm.object_template_manager.get_file_template_handles(obj_name)[0] + + # set a custom collision asset + if args.col_obj is not None: + obj_temp = mm.object_template_manager.get_template_by_handle(obj_temp_handle) + obj_temp.collision_asset_handle = args.col_obj + mm.object_template_manager.register_template(obj_temp) + + cpo = csa.CollisionProxyOptimizer(sim_settings, None, mm) + cpo.setup_obj_gt(obj_temp_handle) + cpo.compute_proxy_metrics(obj_temp_handle) + # setup globals for debug drawing + 
test_points = cpo.gt_data[obj_temp_handle]["test_points"] + pr_raycast_results = cpo.gt_data[obj_temp_handle]["raycasts"]["pr0"] + gt_raycast_results = cpo.gt_data[obj_temp_handle]["raycasts"]["gt"] + + # start the application + HabitatSimInteractiveViewer(sim_settings, mm).exec() diff --git a/examples/spot_viewer.py b/examples/spot_viewer.py new file mode 100644 index 0000000000..0a8d28c85a --- /dev/null +++ b/examples/spot_viewer.py @@ -0,0 +1,1546 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ctypes +import math +import os +import string +import sys +import time +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple + +flags = sys.getdlopenflags() +sys.setdlopenflags(flags | ctypes.RTLD_GLOBAL) + +import habitat.articulated_agents.robots.spot_robot as spot_robot +import habitat.sims.habitat_simulator.sim_utilities as sutils +import magnum as mn +import numpy as np +from habitat.datasets.rearrange.navmesh_utils import get_largest_island_index +from magnum import shaders, text +from magnum.platform.glfw import Application +from omegaconf import DictConfig + +import habitat_sim +from habitat_sim import ReplayRenderer, ReplayRendererConfiguration +from habitat_sim.logging import LoggingContext, logger +from habitat_sim.utils.settings import default_sim_settings, make_cfg + +SPOT_DIR = "data/robots/hab_spot_arm/urdf/hab_spot_arm.urdf" +if not os.path.isfile(SPOT_DIR): + # support other layout + SPOT_DIR = "data/scene_datasets/robots/hab_spot_arm/urdf/hab_spot_arm.urdf" + + +# Describe edit type +class EditMode(Enum): + MOVE = 0 + ROTATE = 1 + NUM_VALS = 2 + + +EDIT_MODE_NAMES = ["Move object", "Rotate object"] + + +# Describe edit distance values +class DistanceMode(Enum): + TINY = 0 + VERY_SMALL = 1 + SMALL = 2 + MEDIUM = 3 + LARGE = 4 + HUGE = 5 + NUM_VALS = 6 + + +# distance values in 
class ExtractedBaseVelNonCylinderAction:
    """
    Velocity-driven base controller for the robot held in `spot`.

    Rather than testing a single point, each candidate move is validated
    against the NavMesh at several offset points (`_navmesh_offset`) expressed
    in the robot's local XZ plane.
    """

    def __init__(self, sim, spot):
        self._sim = sim
        self.spot = spot

        # linear and angular velocities are both controlled, in the local frame
        vel_ctrl = habitat_sim.physics.VelocityControl()
        vel_ctrl.controlling_lin_vel = True
        vel_ctrl.lin_vel_is_local = True
        vel_ctrl.controlling_ang_vel = True
        vel_ctrl.ang_vel_is_local = True
        self.base_vel_ctrl = vel_ctrl

        self._allow_dyn_slide = True
        self._allow_back = True
        # speeds applied to the clamped [-1, 1] inputs of `step`
        self._longitudinal_lin_speed = 10.0
        self._lateral_lin_speed = 10.0
        self._ang_speed = 10.0
        # local (x, z) offsets of the points checked against the NavMesh
        self._navmesh_offset = [[0.0, 0.0], [0.25, 0.0], [-0.25, 0.0]]
        self._enable_lateral_move = True
        self._collision_threshold = 1e-5

    def collision_check(self, trans, target_trans, target_rigid_state, compute_sliding):
        """
        Check a proposed base move against the NavMesh.

        trans: the transformation of the current location of the robot
        target_trans: the transformation of the target location of the robot
        target_rigid_state: the target rigid state of the robot
        compute_sliding: if we want to compute sliding or not

        Returns a `(collided, transformation)` tuple: the target transformation
        when nothing blocks the move, otherwise either a sliding transformation
        (when `compute_sliding`) or the unchanged current transformation.
        """
        # lift the planar offsets into 3D and place them in both frames
        local_offsets = [np.array([x, 0.0, z]) for x, z in self._navmesh_offset]
        cur_pos = [trans.transform_point(offset) for offset in local_offsets]
        goal_pos = [target_trans.transform_point(offset) for offset in local_offsets]

        # step-filter every offset point, then flatten all heights to zero
        end_pos = []
        for cur, goal in zip(cur_pos, goal_pos):
            filtered = self._sim.step_filter(cur, goal)
            filtered[1] = 0.0
            cur[1] = 0.0
            goal[1] = 0.0
            end_pos.append(filtered)

        # planar distance between the NavMesh-clamped point and the goal
        move = [(end - goal).length() for end, goal in zip(end_pos, goal_pos)]

        # any point clamped by more than the threshold counts as a collision
        if not any(dist > self._collision_threshold for dist in move):
            return False, target_trans
        if not compute_sliding:
            return True, trans

        # slide along the move of the most strongly clamped point
        worst_idx = np.argmax(move)
        slide_vec = end_pos[worst_idx] - cur_pos[worst_idx]
        return True, mn.Matrix4.from_(
            target_rigid_state.rotation.to_matrix(), trans.translation + slide_vec
        )

    def update_base(self, if_rotation):
        """
        Integrate the velocity control for one control step and move the base.

        if_rotation: if the robot is rotating or not
        """
        # control frequency in Hz; a single step integrates 1 / ctrl_freq seconds
        ctrl_freq = 60
        trans = self.spot.sim_obj.transformation
        current_state = habitat_sim.RigidState(
            mn.Quaternion.from_matrix(trans.rotation()), trans.translation
        )
        target_rigid_state = self.base_vel_ctrl.integrate_transform(
            1 / ctrl_freq, current_state
        )
        target_trans = mn.Matrix4.from_(
            target_rigid_state.rotation.to_matrix(),
            target_rigid_state.translation,
        )
        # sliding is only used when enabled and the robot is not rotating
        compute_sliding = self._allow_dyn_slide and not if_rotation
        _, new_target_trans = self.collision_check(
            trans, target_trans, target_rigid_state, compute_sliding
        )
        self.spot.sim_obj.transformation = new_target_trans

        if self.spot._base_type == "leg":
            # keep the leg joints at their initial pose
            self.spot.leg_joint_pos = self.spot.params.leg_init_params

    def step(self, forward, lateral, angular):
        """
        provide forward, lateral, and angular velocities as [-1,1].
        """
        longitudinal_lin_vel = np.clip(forward, -1, 1) * self._longitudinal_lin_speed
        lateral_lin_vel = np.clip(lateral, -1, 1) * self._lateral_lin_speed
        ang_vel = np.clip(angular, -1, 1) * self._ang_speed
        if not self._allow_back:
            # backward motion disabled: drop negative longitudinal velocity
            longitudinal_lin_vel = np.maximum(longitudinal_lin_vel, 0)

        self.base_vel_ctrl.linear_velocity = mn.Vector3(
            longitudinal_lin_vel, 0, -lateral_lin_vel
        )
        self.base_vel_ctrl.angular_velocity = mn.Vector3(0, ang_vel, 0)

        # only update the base when some velocity is actually commanded
        commanded = (longitudinal_lin_vel, lateral_lin_vel, ang_vel)
        if any(vel != 0.0 for vel in commanded):
            self.update_base(ang_vel != 0.0)
These chars are used to + # display the CPU/GPU usage data + MAX_DISPLAY_TEXT_CHARS = 256 + + # how much to displace window text relative to the center of the + # app window (e.g if you want the display text in the top left of + # the app window, you will displace the text + # window width * -TEXT_DELTA_FROM_CENTER in the x axis and + # window height * TEXT_DELTA_FROM_CENTER in the y axis, as the text + # position defaults to the middle of the app window) + TEXT_DELTA_FROM_CENTER = 0.49 + + # font size of the magnum in-window display text that displays + # CPU and GPU usage info + DISPLAY_FONT_SIZE = 16.0 + + def __init__(self, sim_settings: Dict[str, Any]) -> None: + self.sim_settings: Dict[str:Any] = sim_settings + + self.enable_batch_renderer: bool = self.sim_settings["enable_batch_renderer"] + self.num_env: int = ( + self.sim_settings["num_environments"] if self.enable_batch_renderer else 1 + ) + + # Compute environment camera resolution based on the number of environments to render in the window. + window_size: mn.Vector2 = ( + self.sim_settings["window_width"], + self.sim_settings["window_height"], + ) + + configuration = self.Configuration() + configuration.title = "Habitat Sim Interactive Viewer" + configuration.size = window_size + Application.__init__(self, configuration) + self.fps: float = 60.0 + + # Compute environment camera resolution based on the number of environments to render in the window. + grid_size: mn.Vector2i = ReplayRenderer.environment_grid_size(self.num_env) + camera_resolution: mn.Vector2 = mn.Vector2(self.framebuffer_size) / mn.Vector2( + grid_size + ) + self.sim_settings["width"] = camera_resolution[0] + self.sim_settings["height"] = camera_resolution[1] + + # draw Bullet debug line visualizations (e.g. 
collision meshes) + self.debug_bullet_draw = False + # draw active contact point debug line visualizations + self.contact_debug_draw = False + # cache most recently loaded URDF file for quick-reload + self.cached_urdf = "" + + # set up our movement map + key = Application.KeyEvent.Key + self.pressed = { + key.UP: False, + key.DOWN: False, + key.LEFT: False, + key.RIGHT: False, + key.A: False, + key.D: False, + key.S: False, + key.W: False, + key.X: False, + key.Z: False, + key.Q: False, + key.E: False, + } + + # Load a TrueTypeFont plugin and open the font file + self.display_font = text.FontManager().load_and_instantiate("TrueTypeFont") + relative_path_to_font = "../data/fonts/ProggyClean.ttf" + self.display_font.open_file( + os.path.join(os.path.dirname(__file__), relative_path_to_font), + 13, + ) + + # Glyphs we need to render everything + self.glyph_cache = text.GlyphCache(mn.Vector2i(256)) + self.display_font.fill_glyph_cache( + self.glyph_cache, + string.ascii_lowercase + + string.ascii_uppercase + + string.digits + + ":-_+,.! 
%ยต", + ) + + # magnum text object that displays CPU/GPU usage data in the app window + self.window_text = text.Renderer2D( + self.display_font, + self.glyph_cache, + HabitatSimInteractiveViewer.DISPLAY_FONT_SIZE, + text.Alignment.TOP_LEFT, + ) + self.window_text.reserve(HabitatSimInteractiveViewer.MAX_DISPLAY_TEXT_CHARS) + + # text object transform in window space is Projection matrix times Translation Matrix + # put text in top left of window + self.window_text_transform = mn.Matrix3.projection( + self.framebuffer_size + ) @ mn.Matrix3.translation( + mn.Vector2(self.framebuffer_size) + * mn.Vector2( + -HabitatSimInteractiveViewer.TEXT_DELTA_FROM_CENTER, + HabitatSimInteractiveViewer.TEXT_DELTA_FROM_CENTER, + ) + ) + self.shader = shaders.VectorGL2D() + + # make magnum text background transparent + mn.gl.Renderer.enable(mn.gl.Renderer.Feature.BLENDING) + mn.gl.Renderer.set_blend_function( + mn.gl.Renderer.BlendFunction.ONE, + mn.gl.Renderer.BlendFunction.ONE_MINUS_SOURCE_ALPHA, + ) + mn.gl.Renderer.set_blend_equation( + mn.gl.Renderer.BlendEquation.ADD, mn.gl.Renderer.BlendEquation.ADD + ) + + # variables that track app data and CPU/GPU usage + self.num_frames_to_track = 60 + + # Editing + # Edit mode + self.curr_edit_mode = EditMode.MOVE + # Edit distance/amount + self.curr_edit_multiplier = DistanceMode.VERY_SMALL + + # Initialize base edit changes + self.set_edit_vals() + + self.previous_mouse_point = None + + # toggle physics simulation on/off + self.simulating = True + + # toggle a single simulation step at the next opportunity if not + # simulating continuously. 
+ self.simulate_single_step = False + + self.spot = None + self.spot_action = None + self.spot_forward = 0 + self.spot_lateral = 0 + self.spot_angular = 0 + self.camera_distance = 2.0 + self.camera_angles = mn.Vector2() + + # object selection and manipulation interface + self.selected_object = None + self.selected_object_orig_transform = mn.Matrix4().identity_init() + self.last_hit_details = None + # cache modified states of any objects moved by the interface. + self.modified_objects_buffer: Dict[ + habitat_sim.physics.ManagedRigidObject, mn.Matrix4 + ] = {} + self.removed_clutter = [] + + self.navmesh_dirty = False + self.removed_objects_debug_frames = [] + + # configure our simulator + self.cfg: Optional[habitat_sim.simulator.Configuration] = None + self.sim: Optional[habitat_sim.simulator.Simulator] = None + self.tiled_sims: list[habitat_sim.simulator.Simulator] = None + self.replay_renderer_cfg: Optional[ReplayRendererConfiguration] = None + self.replay_renderer: Optional[ReplayRenderer] = None + self.reconfigure_sim() + + # compute NavMesh if not already loaded by the scene. + if self.cfg.sim_cfg.scene_id.lower() != "none": + self.navmesh_config_and_recompute() + + self.place_spot() + + self.time_since_last_simulation = 0.0 + LoggingContext.reinitialize_from_env() + logger.setLevel("INFO") + self.print_help_text() + + def set_edit_vals(self): + # Set current scene object edit values for translation and rotation + # 1 cm * multiplier + self.edit_translation_dist = DISTANCE_MODE_VALS[self.curr_edit_multiplier.value] + # 1 radian * multiplier + self.edit_rotation_amt = ( + BASE_EDIT_ROT_AMT * ROTATION_MULT_VALS[self.curr_edit_multiplier.value] + ) + + def draw_removed_objects_debug_frames(self): + """ + Draw debug frames for all the recently removed objects. 
+ """ + for trans, aabb in self.removed_objects_debug_frames: + dblr = self.sim.get_debug_line_render() + dblr.push_transform(trans) + dblr.draw_box(aabb.min, aabb.max, mn.Color4.red()) + dblr.pop_transform() + + def remove_outdoor_objects(self): + """ + Check all object instance and remove those which are marked outdoors. + """ + self.removed_objects_debug_frames = [] + rom = self.sim.get_rigid_object_manager() + for obj in rom.get_objects_by_handle_substring().values(): + if self.obj_is_outdoor(obj): + self.removed_objects_debug_frames.append( + (obj.transformation, obj.root_scene_node.cumulative_bb) + ) + rom.remove_object_by_id(obj.object_id) + + def obj_is_outdoor(self, obj): + """ + Check if an object is outdoors or not by raycasting upwards. + """ + up = mn.Vector3(0, 1.0, 0) + ray_results = self.sim.cast_ray(habitat_sim.geo.Ray(obj.translation, up)) + if ray_results.has_hits(): + for hit in ray_results.hits: + if hit.object_id == obj.object_id: + continue + return False + + # no hits, so outdoors + return True + + def place_spot(self): + if self.sim.pathfinder.is_loaded: + largest_island_ix = get_largest_island_index( + pathfinder=self.sim.pathfinder, + sim=self.sim, + allow_outdoor=False, + ) + print(f"Largest indoor island index = {largest_island_ix}") + valid_spot_point = None + max_attempts = 1000 + attempt = 0 + while valid_spot_point is None and attempt < max_attempts: + spot_point = self.sim.pathfinder.get_random_navigable_point( + island_index=largest_island_ix + ) + if self.sim.pathfinder.distance_to_closest_obstacle(spot_point) >= 0.25: + valid_spot_point = spot_point + attempt += 1 + if valid_spot_point is not None: + self.spot.base_pos = valid_spot_point + + def clear_furniture_joint_states(self): + """ + Clear all furniture object joint states. 
+ """ + for ao in ( + self.sim.get_articulated_object_manager() + .get_objects_by_handle_substring() + .values() + ): + # ignore the robot + if "hab_spot" not in ao.handle: + j_pos = ao.joint_positions + ao.joint_positions = [0.0 for _ in range(len(j_pos))] + j_vel = ao.joint_velocities + ao.joint_velocities = [0.0 for _ in range(len(j_vel))] + + def draw_contact_debug(self): + """ + This method is called to render a debug line overlay displaying active contact points and normals. + Yellow lines show the contact distance along the normal and red lines show the contact normal at a fixed length. + """ + yellow = mn.Color4.yellow() + red = mn.Color4.red() + cps = self.sim.get_physics_contact_points() + self.sim.get_debug_line_render().set_line_width(1.5) + camera_position = self.render_camera.render_camera.node.absolute_translation + # only showing active contacts + active_contacts = (x for x in cps if x.is_active) + for cp in active_contacts: + # red shows the contact distance + self.sim.get_debug_line_render().draw_transformed_line( + cp.position_on_b_in_ws, + cp.position_on_b_in_ws + + cp.contact_normal_on_b_in_ws * -cp.contact_distance, + red, + ) + # yellow shows the contact normal at a fixed length for visualization + self.sim.get_debug_line_render().draw_transformed_line( + cp.position_on_b_in_ws, + # + cp.contact_normal_on_b_in_ws * cp.contact_distance, + cp.position_on_b_in_ws + cp.contact_normal_on_b_in_ws * 0.1, + yellow, + ) + self.sim.get_debug_line_render().draw_circle( + translation=cp.position_on_b_in_ws, + radius=0.005, + color=yellow, + normal=camera_position - cp.position_on_b_in_ws, + ) + + def debug_draw(self): + """ + Additional draw commands to be called during draw_event. 
+ """ + if self.debug_bullet_draw: + render_cam = self.render_camera.render_camera + proj_mat = render_cam.projection_matrix.__matmul__(render_cam.camera_matrix) + self.sim.physics_debug_draw(proj_mat) + if self.contact_debug_draw: + self.draw_contact_debug() + if self.last_hit_details is not None: + self.sim.get_debug_line_render().draw_circle( + translation=self.last_hit_details.point, + radius=0.02, + normal=self.last_hit_details.normal, + color=mn.Color4.yellow(), + num_segments=12, + ) + if self.selected_object is not None: + aabb = None + if isinstance( + self.selected_object, habitat_sim.physics.ManagedBulletRigidObject + ): + aabb = self.selected_object.collision_shape_aabb + else: + aabb = sutils.get_ao_root_bb(self.selected_object) + dblr = self.sim.get_debug_line_render() + dblr.push_transform(self.selected_object.transformation) + dblr.draw_box(aabb.min, aabb.max, mn.Color4.magenta()) + dblr.pop_transform() + + ot = self.selected_object.translation + # draw global coordinate axis + dblr.draw_transformed_line( + ot - mn.Vector3.x_axis(), ot + mn.Vector3.x_axis(), mn.Color4.red() + ) + dblr.draw_transformed_line( + ot - mn.Vector3.y_axis(), ot + mn.Vector3.y_axis(), mn.Color4.green() + ) + dblr.draw_transformed_line( + ot - mn.Vector3.z_axis(), ot + mn.Vector3.z_axis(), mn.Color4.blue() + ) + dblr.draw_circle( + ot + mn.Vector3.x_axis() * 0.95, + radius=0.05, + color=mn.Color4.red(), + normal=mn.Vector3.x_axis(), + ) + dblr.draw_circle( + ot + mn.Vector3.y_axis() * 0.95, + radius=0.05, + color=mn.Color4.green(), + normal=mn.Vector3.y_axis(), + ) + dblr.draw_circle( + ot + mn.Vector3.z_axis() * 0.95, + radius=0.05, + color=mn.Color4.blue(), + normal=mn.Vector3.z_axis(), + ) + self.draw_removed_objects_debug_frames() + + def draw_event( + self, + simulation_call: Optional[Callable] = None, + global_call: Optional[Callable] = None, + active_agent_id_and_sensor_name: Tuple[int, str] = (0, "color_sensor"), + ) -> None: + """ + Calls continuously to re-render 
frames and swap the two frame buffers + at a fixed rate. + """ + agent_acts_per_sec = self.fps + + mn.gl.default_framebuffer.clear( + mn.gl.FramebufferClear.COLOR | mn.gl.FramebufferClear.DEPTH + ) + + # Agent actions should occur at a fixed rate per second + self.time_since_last_simulation += Timer.prev_frame_duration + num_agent_actions: int = self.time_since_last_simulation * agent_acts_per_sec + self.move_and_look(int(num_agent_actions)) + + # Occasionally a frame will pass quicker than 1/60 seconds + if self.time_since_last_simulation >= 1.0 / self.fps: + if self.simulating or self.simulate_single_step: + self.sim.step_world(1.0 / self.fps) + self.simulate_single_step = False + if simulation_call is not None: + simulation_call() + if global_call is not None: + global_call() + if self.navmesh_dirty: + self.navmesh_config_and_recompute() + self.navmesh_dirty = False + + # reset time_since_last_simulation, accounting for potential overflow + self.time_since_last_simulation = math.fmod( + self.time_since_last_simulation, 1.0 / self.fps + ) + + keys = active_agent_id_and_sensor_name + + # set agent position relative to spot + x_rot = mn.Quaternion.rotation( + mn.Rad(self.camera_angles[0]), mn.Vector3(1, 0, 0) + ) + y_rot = mn.Quaternion.rotation( + mn.Rad(self.camera_angles[1]), mn.Vector3(0, 1, 0) + ) + local_camera_vec = mn.Vector3(0, 0, 1) + local_camera_position = y_rot.transform_vector( + x_rot.transform_vector(local_camera_vec * self.camera_distance) + ) + camera_position = local_camera_position + self.spot.base_pos + self.default_agent.scene_node.transformation = mn.Matrix4.look_at( + camera_position, + self.spot.base_pos, + mn.Vector3(0, 1, 0), + ) + + if self.enable_batch_renderer: + self.render_batch() + else: + self.sim._Simulator__sensors[keys[0]][keys[1]].draw_observation() + agent = self.sim.get_agent(keys[0]) + self.render_camera = agent.scene_node.node_sensor_suite.get(keys[1]) + self.debug_draw() + 
def reconfigure_sim(self) -> None:
    """
    Build a simulator configuration from `self.sim_settings` and (re)create the simulators.

    On the first call (`self.sim is None`) one `habitat_sim.Simulator` is created per
    environment and the first becomes `self.sim`. On later calls every existing simulator
    is reconfigured in place, forcing a reload when the scene id is unchanged.

    Afterwards this refreshes `self.default_agent` and `self.render_camera`, writes the
    loaded scene name back into `self.sim_settings["scene"]`, lazily creates the batch
    replay renderer when batch rendering is enabled, checks that no articulated object
    has a non-zero joint position, spawns Spot via `self.init_spot()`, and restarts the
    frame `Timer`.

    :raises AssertionError: if any articulated object in the scene has a non-zero joint position.
    """
    # configure our sim_settings but then set the agent to our default
    self.cfg = make_cfg(self.sim_settings)
    self.agent_id: int = self.sim_settings["default_agent"]
    self.cfg.agents[self.agent_id] = self.default_agent_config()

    if self.enable_batch_renderer:
        self.cfg.enable_batch_renderer = True
        self.cfg.sim_cfg.create_renderer = False
        self.cfg.sim_cfg.enable_gfx_replay_save = True

    if self.sim_settings["use_default_lighting"]:
        logger.info("Setting default lighting override for scene.")
        self.cfg.sim_cfg.override_scene_light_defaults = True
        self.cfg.sim_cfg.scene_light_setup = habitat_sim.gfx.DEFAULT_LIGHTING_KEY

    if self.sim is None:
        self.tiled_sims = [
            habitat_sim.Simulator(self.cfg) for _ in range(self.num_env)
        ]
        self.sim = self.tiled_sims[0]
    else:  # edge case
        for tiled_sim in self.tiled_sims[: self.num_env]:
            if tiled_sim.config.sim_cfg.scene_id == self.cfg.sim_cfg.scene_id:
                # we need to force a reset, so change the internal config scene name
                tiled_sim.config.sim_cfg.scene_id = "NONE"
            tiled_sim.reconfigure(self.cfg)

    # post reconfigure
    self.default_agent = self.sim.get_agent(self.agent_id)
    self.render_camera = self.default_agent.scene_node.node_sensor_suite.get(
        "color_sensor"
    )

    # set sim_settings scene name as actual loaded scene
    self.sim_settings["scene"] = self.sim.curr_scene_name

    # Initialize replay renderer
    if self.enable_batch_renderer and self.replay_renderer is None:
        self.replay_renderer_cfg = ReplayRendererConfiguration()
        self.replay_renderer_cfg.num_environments = self.num_env
        # Context is owned by the GLFW window
        self.replay_renderer_cfg.standalone = False
        self.replay_renderer_cfg.sensor_specifications = self.cfg.agents[
            self.agent_id
        ].sensor_specifications
        self.replay_renderer_cfg.gpu_device_id = self.cfg.sim_cfg.gpu_device_id
        self.replay_renderer_cfg.force_separate_semantic_scene_graph = False
        self.replay_renderer_cfg.leave_context_with_background_renderer = False
        self.replay_renderer = ReplayRenderer.create_batch_replay_renderer(
            self.replay_renderer_cfg
        )
        # Pre-load composite files.
        # Read them from this viewer's own settings: the module-level `sim_settings`
        # name only exists when this file is executed as a script.
        composite_files = self.sim_settings.get("composite_files")
        if composite_files is not None:
            for composite_file in composite_files:
                self.replay_renderer.preload_file(composite_file)

    # check that clearing joint positions on save won't corrupt the content
    for ao in (
        self.sim.get_articulated_object_manager()
        .get_objects_by_handle_substring()
        .values()
    ):
        for joint_val in ao.joint_positions:
            assert (
                joint_val == 0
            ), "If this fails, there are non-zero joint positions in the scene_instance or default pose. Export with 'i' will clear these."

    self.init_spot()

    Timer.start()
    self.step = -1
def invert_gravity(self) -> None:
    """
    Flip the simulator's gravity: read the current gravity vector, multiply it
    by -1 and hand the result back to the simulator. Handy for quickly
    exercising simulation functionality.
    """
    current_gravity = self.sim.get_gravity()
    self.sim.set_gravity(current_gravity * -1)
def key_press_event(self, event: Application.KeyEvent) -> None:
    """
    Handle an `Application.KeyEvent` key press by running the command bound to that key.

    After the command has run, if the pressed key is part of the movement keys map
    `self.pressed`, its entry is set to True so the next `self.move_and_look()` can act
    on it. The event is then marked accepted and a redraw is requested.

    ESC and SHIFT+'i' call `self.exit_event` and return immediately.

    :param event: the key press event; its `key` and `modifiers` (SHIFT, ALT) are read.
    """
    key = event.key
    pressed = Application.KeyEvent.Key
    mod = Application.InputEvent.Modifier

    shift_pressed = bool(event.modifiers & mod.SHIFT)
    alt_pressed = bool(event.modifiers & mod.ALT)
    # warning: ctrl doesn't always pass through with other key-presses

    def arrow_key_edit(negate: bool, move_axis, rot_axis) -> None:
        """Translate along `move_axis` (MOVE mode) or rotate about `rot_axis` (otherwise)."""
        if self.curr_edit_mode == EditMode.MOVE:
            direction = -move_axis if negate else move_axis
            self.move_selected_object(
                translation=direction * self.edit_translation_dist
            )
        else:
            angle = mn.Rad(self.edit_rotation_amt)
            self.move_selected_object(
                rotation=mn.Quaternion.rotation(
                    -angle if negate else angle, rot_axis
                )
            )

    if key == pressed.ESC:
        event.accepted = True
        # NOTE(review): the ExitEvent class itself is passed here, not an instance.
        self.exit_event(Application.ExitEvent)
        return

    elif key == pressed.H:
        self.print_help_text()

    elif key == pressed.TAB:
        pass

    elif key == pressed.SPACE:
        if not self.sim.config.sim_cfg.enable_physics:
            logger.warning("Warning: physics was not enabled during setup")
        else:
            self.simulating = not self.simulating
            logger.info(f"Command: physics simulating set to {self.simulating}")

    elif key == pressed.PERIOD:
        if self.simulating:
            logger.warning("Warning: physics simulation already running")
        else:
            self.simulate_single_step = True
            logger.info("Command: physics step taken")

    elif key == pressed.COMMA:
        self.debug_bullet_draw = not self.debug_bullet_draw
        logger.info(f"Command: toggle Bullet debug draw: {self.debug_bullet_draw}")

    elif key in (pressed.LEFT, pressed.RIGHT):
        # movement mode: +/- global X; rotation mode: rotate around y axis
        arrow_key_edit(
            key == pressed.RIGHT, mn.Vector3.x_axis(), mn.Vector3.y_axis()
        )

    elif key in (pressed.UP, pressed.DOWN):
        negate = key == pressed.DOWN
        if alt_pressed:
            # movement mode: +/- global Y; rotation mode: rotate around x axis
            arrow_key_edit(negate, mn.Vector3.y_axis(), mn.Vector3.x_axis())
        else:
            # movement mode: +/- global Z; rotation mode: rotate around z axis
            arrow_key_edit(negate, mn.Vector3.z_axis(), mn.Vector3.z_axis())

    elif key in (pressed.BACKSPACE, pressed.C):
        if self.selected_object is not None:
            if key == pressed.C:
                # 'c' additionally records the object's name as removed clutter
                obj_name = self.selected_object.handle.split("/")[-1].split("_:")[0]
                self.removed_clutter.append(obj_name)
            print(f"Removed {self.selected_object.handle}")
            if isinstance(
                self.selected_object, habitat_sim.physics.ManagedBulletRigidObject
            ):
                self.sim.get_rigid_object_manager().remove_object_by_handle(
                    self.selected_object.handle
                )
            else:
                self.sim.get_articulated_object_manager().remove_object_by_handle(
                    self.selected_object.handle
                )
            self.selected_object = None
            self.navmesh_config_and_recompute()

    elif key == pressed.B:
        # cycle through edit dist/amount multiplier
        mod_val = -1 if shift_pressed else 1
        self.curr_edit_multiplier = DistanceMode(
            (
                self.curr_edit_multiplier.value
                + DistanceMode.NUM_VALS.value
                + mod_val
            )
            % DistanceMode.NUM_VALS.value
        )
        # update the edit values
        self.set_edit_vals()

    elif key == pressed.G:
        # toggle edit mode
        mod_val = -1 if shift_pressed else 1
        self.curr_edit_mode = EditMode(
            (self.curr_edit_mode.value + EditMode.NUM_VALS.value + mod_val)
            % EditMode.NUM_VALS.value
        )

    elif key == pressed.I:
        # save the modified scene instance; Spot is removed first and rebuilt afterwards
        # (presumably so the robot is not written into the saved scene - TODO confirm)
        aom = self.sim.get_articulated_object_manager()
        spot_loc = self.spot.sim_obj.rigid_state
        aom.remove_object_by_handle(self.spot.sim_obj.handle)

        # clear furniture joint positions before saving
        self.clear_furniture_joint_states()

        self.sim.save_current_scene_config(overwrite=True)
        print("Saved modified scene instance JSON to original location.")
        # de-duplicate and save clutter list
        # NOTE(review): the file is opened in append mode, so saving twice in one
        # session writes the same names again.
        self.removed_clutter = list(dict.fromkeys(self.removed_clutter))
        with open("removed_clutter.txt", "a") as f:
            for obj_name in self.removed_clutter:
                f.write(obj_name + "\n")
        # only exit if shift pressed
        if shift_pressed:
            event.accepted = True
            self.exit_event(Application.ExitEvent)
            return
        # rebuild spot
        self.init_spot()
        # put em back
        self.spot.sim_obj.rigid_state = spot_loc

    elif key == pressed.J:
        if shift_pressed and isinstance(
            self.selected_object, habitat_sim.physics.ManagedArticulatedObject
        ):
            # open the selected receptacle
            for link_ix in self.selected_object.get_link_ids():
                if self.selected_object.get_link_joint_type(link_ix) in [
                    habitat_sim.physics.JointType.Prismatic,
                    habitat_sim.physics.JointType.Revolute,
                ]:
                    sutils.open_link(self.selected_object, link_ix)
        else:
            self.clear_furniture_joint_states()
        self.navmesh_config_and_recompute()

    elif key == pressed.N:
        # (default) - toggle navmesh visualization
        # NOTE: (+ALT) - re-sample the agent position on the NavMesh
        # NOTE: (+SHIFT) - re-compute the NavMesh
        if alt_pressed:
            logger.info("Command: resample agent state from navmesh")
            self.place_spot()
        elif shift_pressed:
            logger.info("Command: recompute navmesh")
            self.navmesh_config_and_recompute()
        elif self.sim.pathfinder.is_loaded:
            self.sim.navmesh_visualization = not self.sim.navmesh_visualization
            logger.info("Command: toggle navmesh")
        else:
            logger.warning("Warning: recompute navmesh first")

    elif key == pressed.T:
        self.remove_outdoor_objects()

    elif key == pressed.U:
        # if an object is selected, restore its last transformation state - UNDO of edits since last selected
        print("Undo selected")
        if self.selected_object is not None:
            # report the saved transform that is about to be restored, not the current one twice
            print(
                f"Sel Obj : {self.selected_object.handle} : Current object transformation : \n{self.selected_object.transformation}\n Being replaced by saved transformation : \n{self.selected_object_orig_transform}"
            )
            orig_mt = self.selected_object.motion_type
            self.selected_object.motion_type = (
                habitat_sim.physics.MotionType.KINEMATIC
            )
            self.selected_object.transformation = (
                self.selected_object_orig_transform
            )
            self.selected_object.motion_type = orig_mt

    elif key == pressed.V:
        # inject a new AO by handle substring in front of the agent

        # get user input (blocks on the console until the user answers)
        ao_substring = input(
            "Load ArticulatedObject. Enter an AO handle substring, first match will be added:"
        ).strip()

        aotm = self.sim.metadata_mediator.ao_template_manager
        aom = self.sim.get_articulated_object_manager()
        ao_handles = aotm.get_template_handles(ao_substring)
        if not ao_handles:
            # fall through so the event is still accepted and the window redrawn
            print(f"No AO found matching substring: '{ao_substring}'")
        else:
            if len(ao_handles) > 1:
                print(f"Multiple AOs found matching substring: '{ao_substring}'.")
            matching_ao_handle = ao_handles[0]
            print(f"Adding AO: '{matching_ao_handle}'")
            aot = aotm.get_template_by_handle(matching_ao_handle)
            aot.base_type = "FIXED"
            aotm.register_template(aot)
            ao = aom.add_articulated_object_by_template_handle(matching_ao_handle)
            if ao is not None:
                recompute_ao_bbs(ao)
                in_front_of_spot = self.spot.base_transformation.transform_point(
                    [1.5, 0.0, 0.0]
                )
                ao.translation = in_front_of_spot
            else:
                print("Failed to load AO.")

    # update map of moving/looking keys which are currently pressed
    if key in self.pressed:
        self.pressed[key] = True
    event.accepted = True
    self.redraw()
def mouse_press_event(self, event: Application.MouseEvent) -> None:
    """
    Handles `Application.MouseEvent` button presses.

    SHIFT + RIGHT-click (when a physics library is reported by the simulator) casts a
    ray through the mouse position and selects the first hit that resolves to an
    object; the hit details and the object's current transformation are stored so a
    later undo can restore it. Every press also records the mouse position, requests a
    redraw and accepts the event.
    """
    modifiers = Application.InputEvent.Modifier
    is_right_click = event.button == Application.MouseEvent.Button.RIGHT
    has_physics = self.sim.get_physics_simulation_library()
    holding_shift = bool(event.modifiers & modifiers.SHIFT)

    # select an object with Shift+RIGHT-click
    if has_physics and is_right_click and holding_shift:
        self.selected_object = None
        camera = self.render_camera.render_camera
        pick_ray = camera.unproject(self.get_mouse_position(event.position))
        raycast = self.sim.cast_ray(ray=pick_ray)
        if raycast.has_hits():
            # walk the hits in order until one maps to a (non-stage) object
            for hit in raycast.hits:
                self.last_hit_details = hit
                self.selected_object = sutils.get_obj_from_id(
                    self.sim, hit.object_id
                )
                if self.selected_object is not None:
                    print(
                        f"Object: {self.selected_object.handle} is {type(self.selected_object)}"
                    )
                    break
            else:
                print("This is the stage.")
        # remember the selection's transformation so an undo can restore it
        if self.selected_object is not None:
            self.selected_object_orig_transform = (
                self.selected_object.transformation
            )

    self.previous_mouse_point = self.get_mouse_position(event.position)
    self.redraw()
    event.accepted = True
def exit_event(self, event: Application.ExitEvent):
    """
    Overrides exit_event to properly close the Simulator(s) before exiting the
    application.

    Every simulator in `self.tiled_sims` is closed with `destroy=True`, the event is
    marked accepted, and the process is terminated with exit status 0.

    :param event: the exit event; its `accepted` flag is set to True.
    :raises SystemExit: always, with code 0.
    """
    for tiled_sim in self.tiled_sims:
        tiled_sim.close(destroy=True)
    event.accepted = True
    # sys.exit rather than the builtin exit(): the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist in every runtime
    sys.exit(0)
class Timer:
    """
    Timer class used to keep track of time between buffer swaps
    and guide the display frame rate.

    All state is stored on the class itself and every method is static, so there is
    a single shared timer and no instance is ever needed.
    """

    # wall-clock timestamp (time.time()) of the last call to start(); 0.0 when stopped
    start_time = 0.0
    # wall-clock timestamp of the most recent frame boundary
    prev_frame_time = 0.0
    # seconds elapsed between the two most recent frame boundaries
    prev_frame_duration = 0.0
    # whether next_frame() should record anything
    running = False

    @staticmethod
    def start() -> None:
        """
        Starts timer and resets previous frame time to the start time.
        """
        Timer.running = True
        Timer.start_time = time.time()
        Timer.prev_frame_time = Timer.start_time
        Timer.prev_frame_duration = 0.0

    @staticmethod
    def stop() -> None:
        """
        Stops timer and erases any previous time data, resetting the timer.
        """
        Timer.running = False
        Timer.start_time = 0.0
        Timer.prev_frame_time = 0.0
        Timer.prev_frame_duration = 0.0

    @staticmethod
    def next_frame() -> None:
        """
        Records previous frame duration and updates the previous frame timestamp
        to the current time. If the timer is not currently running, perform nothing.
        """
        if not Timer.running:
            return
        # sample the clock exactly once: reading it separately for the duration and
        # for the new timestamp would drop the time between the two reads from every
        # frame, so the recorded durations would not add up to the elapsed time
        now = time.time()
        Timer.prev_frame_duration = now - Timer.prev_frame_time
        Timer.prev_frame_time = now
The number of concurrent environments is specified with the num-environments parameter.", + ) + parser.add_argument( + "--num-environments", + default=1, + type=int, + help="Number of concurrent environments to batch render. Note that only the first environment simulates physics and can be controlled.", + ) + parser.add_argument( + "--composite-files", + type=str, + nargs="*", + help="Composite files that the batch renderer will use in-place of simulation assets to improve memory usage and performance. If none is specified, the original scene files will be loaded from disk.", + ) + parser.add_argument( + "--width", + default=1080, + type=int, + help="Horizontal resolution of the window.", + ) + parser.add_argument( + "--height", + default=720, + type=int, + help="Vertical resolution of the window.", + ) + + args = parser.parse_args() + + if args.num_environments < 1: + parser.error("num-environments must be a positive non-zero integer.") + if args.width < 1: + parser.error("width must be a positive non-zero integer.") + if args.height < 1: + parser.error("height must be a positive non-zero integer.") + + # Setting up sim_settings + sim_settings: Dict[str, Any] = default_sim_settings + sim_settings["scene"] = args.scene + sim_settings["scene_dataset_config_file"] = args.dataset + sim_settings["enable_physics"] = not args.disable_physics + sim_settings["use_default_lighting"] = args.use_default_lighting + sim_settings["enable_batch_renderer"] = args.enable_batch_renderer + sim_settings["num_environments"] = args.num_environments + sim_settings["composite_files"] = args.composite_files + sim_settings["window_width"] = args.width + sim_settings["window_height"] = args.height + sim_settings["sensor_height"] = 0 + sim_settings["enable_hbao"] = args.hbao + + # start the application + HabitatSimInteractiveViewer(sim_settings).exec() diff --git a/examples/viewer.py b/examples/viewer.py index 78df8231ba..a2afca9349 100644 --- a/examples/viewer.py +++ b/examples/viewer.py @@ 
class RecColorMode(Enum):
    """
    Defines the coloring mode for receptacle debug drawing.

    The "GT_*" modes read metrics stored under the "gt" shape id and the
    "PR_*" modes read metrics stored under the "pr0" shape id of the
    CollisionProxyOptimizer results. "*_ACCESS" modes use the receptacle
    access score and "*_STABILITY" modes use the stability success ratio;
    both are mapped through the red-to-green color lerp.
    """

    DEFAULT = 0  # all magenta
    GT_ACCESS = 1  # red to green
    GT_STABILITY = 2  # red to green
    PR_ACCESS = 3  # red to green
    PR_STABILITY = 4  # red to green
    # colored by filter status:
    #   green   = active
    #   yellow  = manually filtered
    #   red     = automatically filtered (access)
    #   magenta = automatically filtered (stability)
    #   orange  = automatically filtered (height)
    #   blue    = no entry in the filter data (e.g. receptacle newer than the filter file)
    FILTERING = 5
+ """ + + def __init__(self, c0: mn.Color4, c1: mn.Color4): + self.c0 = c0.to_xyz() + self.c1 = c1.to_xyz() + self.delta = self.c1 - self.c0 + + def at(self, t: float) -> mn.Color4: + """ + Compute the LERP at time t [0,1]. + """ + assert t >= 0 and t <= 1, "Extrapolation not recommended in color space." + t_color_xyz = self.c0 + self.delta * t + return mn.Color4.from_xyz(t_color_xyz) + + +# red to green lerp for heatmaps +rg_lerp = ColorLERP(mn.Color4.red(), mn.Color4.green()) + class HabitatSimInteractiveViewer(Application): # the maximum number of chars displayable in the app window @@ -44,7 +106,11 @@ class HabitatSimInteractiveViewer(Application): # CPU and GPU usage info DISPLAY_FONT_SIZE = 16.0 - def __init__(self, sim_settings: Dict[str, Any]) -> None: + def __init__( + self, + sim_settings: Dict[str, Any], + mm: Optional[habitat_sim.metadata.MetadataMediator] = None, + ) -> None: self.sim_settings: Dict[str:Any] = sim_settings self.enable_batch_renderer: bool = self.sim_settings["enable_batch_renderer"] @@ -72,16 +138,6 @@ def __init__(self, sim_settings: Dict[str, Any]) -> None: self.sim_settings["width"] = camera_resolution[0] self.sim_settings["height"] = camera_resolution[1] - # draw Bullet debug line visualizations (e.g. 
collision meshes) - self.debug_bullet_draw = False - # draw active contact point debug line visualizations - self.contact_debug_draw = False - # draw semantic region debug visualizations if present - self.semantic_region_debug_draw = False - - # cache most recently loaded URDF file for quick-reload - self.cached_urdf = "" - # set up our movement map key = Application.KeyEvent.Key self.pressed = { @@ -159,6 +215,34 @@ def __init__(self, sim_settings: Dict[str, Any]) -> None: # variables that track app data and CPU/GPU usage self.num_frames_to_track = 60 + # Descriptive strings for semantic region debug draw possible choices + self.semantic_region_debug_draw_choices = ["None", "Kitchen Only", "All"] + + global _cpo + self._cpo = _cpo + self.cpo_initialized = False + self.proxy_obj_postfix = "_collision_stand-in" + + # initialization code below here + # TODO isolate all initialization so tabbing through scenes can be properly supported + # configure our simulator + self.cfg: Optional[habitat_sim.simulator.Configuration] = None + self.sim: Optional[habitat_sim.simulator.Simulator] = None + self.tiled_sims: list[habitat_sim.simulator.Simulator] = None + self.replay_renderer_cfg: Optional[ReplayRendererConfiguration] = None + self.replay_renderer: Optional[ReplayRenderer] = None + + # draw Bullet debug line visualizations (e.g. collision meshes) + self.debug_bullet_draw = False + # draw active contact point debug line visualizations + self.contact_debug_draw = False + # draw semantic region debug visualizations if present : should be [0 : len(semantic_region_debug_draw_choices)-1] + self.semantic_region_debug_draw_state = 0 + # Colors to use for each region's semantic rendering. 
+ self.debug_semantic_colors = {} + + # cache most recently loaded URDF file for quick-reload + self.cached_urdf = "" # Cycle mouse utilities self.mouse_interaction = MouseMode.LOOK @@ -166,25 +250,79 @@ def __init__(self, sim_settings: Dict[str, Any]) -> None: self.previous_mouse_point = None # toggle physics simulation on/off - self.simulating = True - + self.simulating = False # toggle a single simulation step at the next opportunity if not # simulating continuously. self.simulate_single_step = False - # configure our simulator - self.cfg: Optional[habitat_sim.simulator.Configuration] = None - self.sim: Optional[habitat_sim.simulator.Simulator] = None - self.tiled_sims: list[habitat_sim.simulator.Simulator] = None - self.replay_renderer_cfg: Optional[ReplayRendererConfiguration] = None - self.replay_renderer: Optional[ReplayRenderer] = None - self.reconfigure_sim() - self.debug_semantic_colors = {} + # receptacle visualization + self.receptacles = None + self.display_receptacles = False + self.show_filtered = True + self.rec_access_filter_threshold = 0.12 # empirically chosen + self.rec_color_mode = RecColorMode.FILTERING + # map receptacle to parent objects + self.rec_to_poth: Dict[hab_receptacle.Receptacle, str] = {} + self.poh_to_rec: Dict[str, List[hab_receptacle.Receptacle]] = {} + # contains filtering metadata and classification of meshes filtered automatically and manually + self.rec_filter_data = None + # TODO need to determine filter path for each scene during tabbing? 
+ # Currently this field is only set as command-line argument + self.rec_filter_path = self.sim_settings["rec_filter_file"] + + # display stability samples for selected object w/ receptacle + self.display_selected_stability_samples = True + + # collision proxy visualization + self.col_proxy_objs = None + self.col_proxies_visible = True + self.original_objs_visible = True + + # mouse raycast visualization + self.mouse_cast_results = None + # last clicked or None for stage + self.selected_object = None + self.selected_rec = None + self.ao_link_map = None + + # index of the largest indoor island + self.largest_island_ix = -1 + + # Sim reconfigure + self.reconfigure_sim(mm) + # load appropriate filter file for scene + self.load_scene_filter_file() + + # ----------------------------------------- + # Clutter Generation Integration: + self.clutter_object_set = [ + "002_master_chef_can", + "003_cracker_box", + "004_sugar_box", + "005_tomato_soup_can", + "007_tuna_fish_can", + "008_pudding_box", + "009_gelatin_box", + "010_potted_meat_can", + "024_bowl", + ] + self.clutter_object_handles = [] + self.clutter_object_instances = [] + # cache initial states for classification of unstable objects + self.clutter_object_initial_states = [] + self.num_unstable_objects = 0 + # add some clutter objects to the MM + self.sim.metadata_mediator.object_template_manager.load_configs( + "data/objects/ycb/configs/" + ) + self.initialize_clutter_object_set() + # ----------------------------------------- # compute NavMesh if not already loaded by the scene. if ( not self.sim.pathfinder.is_loaded and self.cfg.sim_cfg.scene_id.lower() != "none" + and not self.sim_settings["viewer_ignore_navmesh"] ): self.navmesh_config_and_recompute() @@ -193,6 +331,268 @@ def __init__(self, sim_settings: Dict[str, Any]) -> None: logger.setLevel("INFO") self.print_help_text() + def modify_param_from_term(self): + """ + Prompts the user to enter an attribute name and new value. 
def load_scene_filter_file(self):
    """
    Load the receptacle filter file configured for the current scene, if any.

    The current scene's user_defined configuration is queried for a
    "scene_filter_file" entry. When present, the entry is resolved relative
    to the directory of the active scene dataset config, the receptacles are
    (re)loaded, the filter annotations are read from that file and
    `self.rec_filter_path` is updated to point at it. When absent, a warning
    is printed and no state is modified.
    """
    # Use the simulator's own MetadataMediator. The previous implementation
    # read `mm.active_dataset` through a free name `mm`, which only resolved
    # when a module-level global of that name happened to exist.
    metadata_mediator = self.sim.metadata_mediator
    scene_user_defined = metadata_mediator.get_scene_user_defined(
        self.sim.curr_scene_name
    )
    if scene_user_defined is not None and scene_user_defined.has_value(
        "scene_filter_file"
    ):
        scene_filter_file = scene_user_defined.get("scene_filter_file")
        # construct the dataset level path for the filter data file
        scene_filter_file = os.path.join(
            os.path.dirname(metadata_mediator.active_dataset), scene_filter_file
        )
        print(f"scene_filter_file = {scene_filter_file}")
        self.load_receptacles()
        self.load_filtered_recs(scene_filter_file)
        self.rec_filter_path = scene_filter_file
    else:
        print(
            f"WARNING: No rec filter file configured for scene {self.sim.curr_scene_name}."
        )
+ + :param pos: The point to compare with receptacle verts. + :param max_dist: The maximum allowable distance to the receptacle to count. + + :return: None if failed or closest receptacle. + """ + if self.receptacles is None or not self.display_receptacles: + return None + closest_rec = None + closest_rec_dist = max_dist + recs = ( + self.receptacles + if ( + self.selected_object is None + or self.selected_object.handle not in self.poh_to_rec + ) + else self.poh_to_rec[self.selected_object.handle] + ) + for receptacle in recs: + g_trans = receptacle.get_global_transform(self.sim) + # transform the query point once instead of all verts + local_point = g_trans.inverted().transform_point(pos) + if (g_trans.translation - pos).length() < max_dist: + # receptacles object transform should be close to the point + if isinstance(receptacle, hab_receptacle.TriangleMeshReceptacle): + for vert in receptacle.mesh_data.attribute( + mn.trade.MeshAttribute.POSITION + ): + v_dist = (local_point - vert).length() + if v_dist < closest_rec_dist: + closest_rec_dist = v_dist + closest_rec = receptacle + else: + global_keypoints = None + if isinstance(receptacle, hab_receptacle.AABBReceptacle): + global_keypoints = sutils.get_global_keypoints_from_bb( + receptacle.bounds, g_trans + ) + elif isinstance(receptacle, hab_receptacle.AnyObjectReceptacle): + global_keypoints = sutils.get_bb_corners( + receptacle._get_global_bb(self.sim) + ) + + for g_point in global_keypoints: + v_dist = (pos - g_point).length() + if v_dist < closest_rec_dist: + closest_rec_dist = v_dist + closest_rec = receptacle + + return closest_rec + + def compute_rec_filter_state( + self, + access_threshold: float = 0.12, + stab_threshold: float = 0.5, + filter_shape: str = "pr0", + ) -> None: + """ + Check all receptacles against automated filters to fill the + + :param access_threshold: Access threshold for filtering. Roughly % of sample points with some raycast access. 
+ :param stab_threshold: Stability threshold for filtering. Roughly % of sample points with stable object support. + :param filter_shape: Which shape metrics to use for filter. Choices typically "gt"(ground truth) or "pr0"(proxy shape). + """ + # load receptacles if not done + if self.receptacles is None: + self.load_receptacles() + assert ( + self._cpo is not None + ), "Must initialize the CPO before automatic filtering. Re-run with '--init-cpo'." + + # initialize if necessary + if self.rec_filter_data is None: + self.rec_filter_data = { + "active": [], + "manually_filtered": [], + "access_filtered": [], + "access_threshold": access_threshold, # set in filter procedure + "stability_filtered": [], + "stability threshold": stab_threshold, # set in filter procedure + # TODO: + "height_filtered": [], + "max_height": 0, + "min_height": 0, + } + + for rec in self.receptacles: + rec_unique_name = rec.unique_name + # respect already marked receptacles + if rec_unique_name not in self.rec_filter_data["manually_filtered"]: + rec_dat = self._cpo.gt_data[self.rec_to_poth[rec]]["receptacles"][ + rec.name + ] + rec_shape_data = rec_dat["shape_id_results"][filter_shape] + # filter by access + if ( + "access_results" in rec_shape_data + and rec_shape_data["access_results"]["receptacle_access_score"] + < access_threshold + ): + self.rec_filter_data["access_filtered"].append(rec_unique_name) + # filter by stability + elif ( + "stability_results" in rec_shape_data + and rec_shape_data["stability_results"]["success_ratio"] + < stab_threshold + ): + self.rec_filter_data["stability_filtered"].append(rec_unique_name) + # TODO: add more filters + # TODO: 1. filter by height relative to the floor + # TODO: 2. 
def export_filtered_recs(self, filepath: Optional[str] = None) -> None:
    """
    Save a JSON with filtering metadata and filtered Receptacles for a scene.

    :param filepath: Defines the output filename for this JSON. If omitted, defaults to "./rec_filter_data.json".
    """
    if filepath is None:
        filepath = "rec_filter_data.json"
    # Serialize first so that a non-serializable entry raises before any
    # directory is created or an existing file is truncated.
    serialized = json.dumps(self.rec_filter_data, indent=2)
    # A bare filename (such as the default) has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually names one.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filepath, "w") as f:
        f.write(serialized)
    print(f"Exported filter annotations to {filepath}.")
def add_col_proxy_object(
    self, obj_instance: habitat_sim.physics.ManagedRigidObject
) -> habitat_sim.physics.ManagedRigidObject:
    """
    Spawn a visualization stand-in which renders the collision asset of the given object.

    A template is fetched by the object's creation handle, modified to render
    its collision asset at a slightly enlarged scale with collisions disabled,
    and registered under the original template handle plus
    `self.proxy_obj_postfix`. The instance created from it is made kinematic,
    posed on top of the source object and has bounding box drawing enabled.

    :param obj_instance: The object to create an overlapping proxy for.

    :return: The newly added proxy object.
    """
    source_handle = obj_instance.creation_attributes.handle
    template_mngr = self.sim.get_object_template_manager()
    proxy_template = template_mngr.get_template_by_handle(source_handle)

    # enlarge by 0.01 per axis relative to the source object's scale
    inflation = np.ones(3) * 0.01
    proxy_template.scale = obj_instance.scale + inflation
    # show the collision geometry in place of the render geometry
    proxy_template.render_asset_handle = proxy_template.collision_asset_handle
    proxy_template.is_collidable = False

    proxy_template_handle = proxy_template.handle + self.proxy_obj_postfix
    proxy_template_id = template_mngr.register_template(
        proxy_template, proxy_template_handle
    )

    rigid_obj_mngr = self.sim.get_rigid_object_manager()
    proxy_obj = rigid_obj_mngr.add_object_by_template_id(proxy_template_id)
    proxy_obj.motion_type = habitat_sim.physics.MotionType.KINEMATIC
    # overlap the proxy with the object it stands in for
    proxy_obj.translation = obj_instance.translation
    proxy_obj.rotation = obj_instance.rotation
    self.sim.set_object_bb_draw(True, proxy_obj.object_id)
    return proxy_obj
@@ -231,15 +631,27 @@ def draw_region_debug(self, debug_line_render: Any) -> None: """ Draw the semantic region wireframes. """ - - for region in self.sim.semantic_scene.regions: - color = self.debug_semantic_colors.get(region.id, mn.Color4.magenta()) - for edge in region.volume_edges: - debug_line_render.draw_transformed_line( - edge[0], - edge[1], - color, - ) + if self.semantic_region_debug_draw_state == 1: + for region in self.sim.semantic_scene.regions: + if "kitchen" not in region.id.lower(): + continue + color = self.debug_semantic_colors.get(region.id, mn.Color4.magenta()) + for edge in region.volume_edges: + debug_line_render.draw_transformed_line( + edge[0], + edge[1], + color, + ) + else: + # Draw all + for region in self.sim.semantic_scene.regions: + color = self.debug_semantic_colors.get(region.id, mn.Color4.magenta()) + for edge in region.volume_edges: + debug_line_render.draw_transformed_line( + edge[0], + edge[1], + color, + ) def debug_draw(self): """ @@ -254,13 +666,173 @@ def debug_draw(self): if self.contact_debug_draw: self.draw_contact_debug(debug_line_render) - if self.semantic_region_debug_draw: + if self.semantic_region_debug_draw_state != 0: if len(self.debug_semantic_colors) != len(self.sim.semantic_scene.regions): + self.debug_semantic_colors = {} for region in self.sim.semantic_scene.regions: self.debug_semantic_colors[region.id] = mn.Color4( mn.Vector3(np.random.random(3)) ) self.draw_region_debug(debug_line_render) + if self.receptacles is not None and self.display_receptacles: + if self.rec_filter_data is None and self.cpo_initialized: + self.compute_rec_filter_state( + access_threshold=self.rec_access_filter_threshold + ) + c_pos = self.render_camera.node.absolute_translation + c_forward = ( + self.render_camera.node.absolute_transformation().transform_vector( + mn.Vector3(0, 0, -1) + ) + ) + for receptacle in self.receptacles: + rec_unique_name = receptacle.unique_name + # filter all non-active receptacles + if ( + 
self.rec_filter_data is not None + and not self.show_filtered + and rec_unique_name not in self.rec_filter_data["active"] + ): + continue + + rec_dat = None + if self.cpo_initialized: + rec_dat = self._cpo.gt_data[self.rec_to_poth[receptacle]][ + "receptacles" + ][receptacle.name] + + r_trans = receptacle.get_global_transform(self.sim) + # display point samples for selected object + if ( + rec_dat is not None + and self.display_selected_stability_samples + and self.selected_object is not None + and self.selected_object.handle == receptacle.parent_object_handle + ): + # display colored circles for stability samples on the selected object + point_metric_dat = rec_dat["shape_id_results"]["gt"][ + "access_results" + ]["receptacle_point_access_scores"] + if self.rec_color_mode == RecColorMode.GT_STABILITY: + point_metric_dat = rec_dat["shape_id_results"]["gt"][ + "stability_results" + ]["point_stabilities"] + elif self.rec_color_mode == RecColorMode.PR_STABILITY: + point_metric_dat = rec_dat["shape_id_results"]["pr0"][ + "stability_results" + ]["point_stabilities"] + elif self.rec_color_mode == RecColorMode.PR_ACCESS: + point_metric_dat = rec_dat["shape_id_results"]["pr0"][ + "access_results" + ]["receptacle_point_access_scores"] + + for point_metric, point in zip( + point_metric_dat, + rec_dat["sample_points"], + ): + self.sim.get_debug_line_render().draw_circle( + translation=r_trans.transform_point(point), + radius=0.02, + normal=mn.Vector3(0, 1, 0), + color=rg_lerp.at(point_metric), + num_segments=12, + ) + + rec_obj = sutils.get_obj_from_handle( + self.sim, receptacle.parent_object_handle + ) + key_points = [r_trans.translation] + key_points.extend( + sutils.get_bb_corners(rec_obj.root_scene_node.cumulative_bb) + ) + + in_view = False + for ix, key_point in enumerate(key_points): + r_pos = key_point + if ix > 0: + r_pos = rec_obj.transformation.transform_point(key_point) + c_to_r = r_pos - c_pos + # only display receptacles within 8 meters centered in view + if ( + 
c_to_r.length() < 8 + and mn.math.dot((c_to_r).normalized(), c_forward) > 0.7 + ): + in_view = True + break + if in_view: + # handle coloring + rec_color = None + if self.selected_rec == receptacle: + # white + rec_color = mn.Color4.cyan() + elif ( + self.rec_filter_data is not None + ) and self.rec_color_mode == RecColorMode.FILTERING: + # blue indicates no filter data for the receptacle, it may be newer than the filter file. + rec_color = mn.Color4.blue() + if rec_unique_name in self.rec_filter_data["active"]: + rec_color = mn.Color4.green() + elif ( + rec_unique_name in self.rec_filter_data["manually_filtered"] + ): + rec_color = mn.Color4.yellow() + elif rec_unique_name in self.rec_filter_data["access_filtered"]: + rec_color = mn.Color4.red() + elif ( + rec_unique_name + in self.rec_filter_data["stability_filtered"] + ): + rec_color = mn.Color4.magenta() + elif rec_unique_name in self.rec_filter_data["height_filtered"]: + # orange + rec_color = mn.Color4(1.0, 0.66, 0.0, 1.0) + elif ( + self.cpo_initialized + and self.rec_color_mode != RecColorMode.DEFAULT + ): + if self.rec_color_mode == RecColorMode.GT_STABILITY: + rec_color = rg_lerp.at( + rec_dat["shape_id_results"]["gt"]["stability_results"][ + "success_ratio" + ] + ) + elif self.rec_color_mode == RecColorMode.GT_ACCESS: + rec_color = rg_lerp.at( + rec_dat["shape_id_results"]["gt"]["access_results"][ + "receptacle_access_score" + ] + ) + elif self.rec_color_mode == RecColorMode.PR_STABILITY: + rec_color = rg_lerp.at( + rec_dat["shape_id_results"]["pr0"]["stability_results"][ + "success_ratio" + ] + ) + elif self.rec_color_mode == RecColorMode.PR_ACCESS: + rec_color = rg_lerp.at( + rec_dat["shape_id_results"]["pr0"]["access_results"][ + "receptacle_access_score" + ] + ) + + receptacle.debug_draw(self.sim, color=rec_color) + if True: + dblr = self.sim.get_debug_line_render() + t_form = receptacle.get_global_transform(self.sim) + dblr.push_transform(t_form) + dblr.draw_transformed_line( + mn.Vector3(0), 
def initialize_clutter_object_set(self) -> None:
    """
    Resolve every name in `self.clutter_object_set` to a registered template handle.

    `self.clutter_object_handles` is reset and then filled, in order, with the
    first handle the object template manager returns for each configured name.
    An AssertionError is raised for the first name without any match.
    """
    resolved_handles = []
    self.clutter_object_handles = resolved_handles
    for template_name in self.clutter_object_set:
        template_mngr = self.sim.metadata_mediator.object_template_manager
        candidates = template_mngr.get_template_handles(template_name)
        assert (
            len(candidates) > 0
        ), f"No matching template for '{template_name}' in the dataset."
        # only the first match is used
        resolved_handles.append(candidates[0])
def clear_furniture_joint_states(self):
    """
    Zero the joint positions and joint velocities of every articulated object
    returned by the articulated object manager.
    """
    ao_mngr = self.sim.get_articulated_object_manager()
    articulated_objects = ao_mngr.get_objects_by_handle_substring()
    for art_obj in articulated_objects.values():
        # replace each state vector with zeros of the same length
        num_positions = len(art_obj.joint_positions)
        art_obj.joint_positions = [0.0] * num_positions
        num_velocities = len(art_obj.joint_velocities)
        art_obj.joint_velocities = [0.0] * num_velocities
+ """ + print(f"Checking Receptacle accessibility for {rec.unique_name}") + + # first check if the receptacle is close enough to the navmesh + rec_global_keypoints = sutils.get_global_keypoints_from_bb( + rec.bounds, rec.get_global_transform(self.sim) + ) + floor_point = None + for keypoint in rec_global_keypoints: + floor_point = self.sim.pathfinder.snap_point( + keypoint, island_index=self.largest_island_ix + ) + if not np.isnan(floor_point[0]): + break + if np.isnan(floor_point[0]): + print(" - Receptacle too far from active navmesh boundary.") + return False, "access_filtered" + + # then check that the height is acceptable + rec_min = min(rec_global_keypoints, key=lambda x: x[1]) + if rec_min[1] - floor_point[1] > max_height: + print( + f" - Receptacle exceeds maximum height {rec_min[1]-floor_point[1]} vs {max_height}." + ) + return False, "height_filtered" + + # try to sample 10 objects on the receptacle + target_number = 10 + obj_samp = ObjectSampler( + self.clutter_object_handles, + ["rec set"], + orientation_sample="up", + num_objects=(1, target_number), + ) + obj_samp.max_sample_attempts = len(self.clutter_object_handles) + obj_samp.max_placement_attempts = 10 + obj_samp.target_objects_number = target_number + rec_set_unique_names = [rec.unique_name] + rec_set_obj = hab_receptacle.ReceptacleSet( + "rec set", [""], [], rec_set_unique_names, [] + ) + recep_tracker = hab_receptacle.ReceptacleTracker( + {}, + {"rec set": rec_set_obj}, + ) + new_objs = obj_samp.sample(self.sim, recep_tracker, [], snap_down=True) + + # if we can't sample objects, this receptacle is out + if len(new_objs) == 0: + print(" - failed to sample any objects.") + return False, "access_filtered" + print(f" - sampled {len(new_objs)} / {target_number} objects.") + + for obj, _rec in new_objs: + self.clutter_object_instances.append(obj) + self.clutter_object_initial_states.append((obj.translation, obj.rotation)) + + # now try unoccluded navmesh snapping to the objects to test accessibility 
def set_filter_status_for_rec(
    self, rec: hab_receptacle.Receptacle, filter_status: str
) -> None:
    """
    Move a receptacle into the given filter category of `self.rec_filter_data`.

    The receptacle's unique name is removed (one occurrence) from each category
    list which contains it and then appended to the requested category.

    :param rec: The receptacle to re-classify.
    :param filter_status: The target category. One of "access_filtered", "stability_filtered", "height_filtered", "manually_filtered" or "active".
    """
    known_statuses = (
        "access_filtered",
        "stability_filtered",
        "height_filtered",
        "manually_filtered",
        "active",
    )
    assert filter_status in known_statuses
    rec_name = rec.unique_name
    for status in known_statuses:
        bucket = self.rec_filter_data[status]
        if rec_name in bucket:
            bucket.remove(rec_name)
    # re-file under the requested status (lands at the end of that list)
    self.rec_filter_data[filter_status].append(rec_name)
`Application.KeyEvent` on a key press by performing the corresponding functions. @@ -499,45 +1277,13 @@ def key_press_event(self, event: Application.KeyEvent) -> None: self.exit_event(Application.ExitEvent) return - elif key == pressed.H: - self.print_help_text() - elif key == pressed.J: - logger.info( - f"Toggle Region Draw from {self.semantic_region_debug_draw } to {not self.semantic_region_debug_draw}" - ) - # Toggle visualize semantic bboxes. Currently only regions supported - self.semantic_region_debug_draw = not self.semantic_region_debug_draw + elif key == pressed.SIX: + # Reset mouse wheel FOV zoom + self.render_camera.reset_zoom() elif key == pressed.TAB: - # NOTE: (+ALT) - reconfigure without cycling scenes - if not alt_pressed: - # cycle the active scene from the set available in MetadataMediator - inc = -1 if shift_pressed else 1 - scene_ids = self.sim.metadata_mediator.get_scene_handles() - cur_scene_index = 0 - if self.sim_settings["scene"] not in scene_ids: - matching_scenes = [ - (ix, x) - for ix, x in enumerate(scene_ids) - if self.sim_settings["scene"] in x - ] - if not matching_scenes: - logger.warning( - f"The current scene, '{self.sim_settings['scene']}', is not in the list, starting cycle at index 0." - ) - else: - cur_scene_index = matching_scenes[0][0] - else: - cur_scene_index = scene_ids.index(self.sim_settings["scene"]) - - next_scene_index = min( - max(cur_scene_index + inc, 0), len(scene_ids) - 1 - ) - self.sim_settings["scene"] = scene_ids[next_scene_index] - self.reconfigure_sim() - logger.info( - f"Reconfigured simulator for scene: {self.sim_settings['scene']}" - ) + # Cycle through scenes + self.cycleScene(True, shift_pressed=shift_pressed) elif key == pressed.SPACE: if not self.sim.config.sim_cfg.enable_physics: @@ -572,58 +1318,36 @@ def key_press_event(self, event: Application.KeyEvent) -> None: self.contact_debug_draw = True # TODO: add a nice log message with concise contact pair naming. 
- elif key == pressed.T: - # load URDF - fixed_base = alt_pressed - urdf_file_path = "" - if shift_pressed and self.cached_urdf: - urdf_file_path = self.cached_urdf + elif key == pressed.F: + # toggle, load(+ALT), or save(+SHIFT) filtering + if shift_pressed and self.rec_filter_data is not None: + self.export_filtered_recs(self.rec_filter_path) + elif alt_pressed: + self.load_filtered_recs(self.rec_filter_path) else: - urdf_file_path = input("Load URDF: provide a URDF filepath:").strip() - - if not urdf_file_path: - logger.warn("Load URDF: no input provided. Aborting.") - elif not urdf_file_path.endswith((".URDF", ".urdf")): - logger.warn("Load URDF: input is not a URDF. Aborting.") - elif os.path.exists(urdf_file_path): - self.cached_urdf = urdf_file_path - aom = self.sim.get_articulated_object_manager() - ao = aom.add_articulated_object_from_urdf( - urdf_file_path, - fixed_base, - 1.0, - 1.0, - True, - maintain_link_order=False, - intertia_from_urdf=False, - ) - ao.translation = ( - self.default_agent.scene_node.transformation.transform_point( - [0.0, 1.0, -1.5] - ) - ) - # check removal and auto-creation - joint_motor_settings = habitat_sim.physics.JointMotorSettings( - position_target=0.0, - position_gain=1.0, - velocity_target=0.0, - velocity_gain=1.0, - max_impulse=1000.0, - ) - existing_motor_ids = ao.existing_joint_motor_ids - for motor_id in existing_motor_ids: - ao.remove_joint_motor(motor_id) - ao.create_all_motors(joint_motor_settings) - else: - logger.warn("Load URDF: input file not found. 
Aborting.") + self.show_filtered = not self.show_filtered + print(f"self.show_filtered = {self.show_filtered}") + + elif key == pressed.H: + self.print_help_text() + + elif key == pressed.J: + self.clear_furniture_joint_states() + + elif key == pressed.K: + new_state_idx = (self.semantic_region_debug_draw_state + 1) % len( + self.semantic_region_debug_draw_choices + ) + logger.info( + f"Change Region Draw from {self.semantic_region_debug_draw_choices[self.semantic_region_debug_draw_state]} to {self.semantic_region_debug_draw_choices[new_state_idx]}" + ) + # Increment visualize semantic bboxes. Currently only regions supported + self.semantic_region_debug_draw_state = new_state_idx elif key == pressed.M: self.cycle_mouse_mode() logger.info(f"Command: mouse mode set to {self.mouse_interaction}") - elif key == pressed.V: - self.invert_gravity() - logger.info("Command: gravity inverted") elif key == pressed.N: # (default) - toggle navmesh visualization # NOTE: (+ALT) - re-sample the agent position on the NavMesh @@ -632,8 +1356,12 @@ def key_press_event(self, event: Application.KeyEvent) -> None: logger.info("Command: resample agent state from navmesh") if self.sim.pathfinder.is_loaded: new_agent_state = habitat_sim.AgentState() + + print(f"Largest indoor island index = {self.largest_island_ix}") new_agent_state.position = ( - self.sim.pathfinder.get_random_navigable_point() + self.sim.pathfinder.get_random_navigable_point( + island_index=self.largest_island_ix + ) ) new_agent_state.rotation = quat_from_angle_axis( self.sim.random.uniform_float(0, 2.0 * np.pi), @@ -654,6 +1382,199 @@ def key_press_event(self, event: Application.KeyEvent) -> None: else: logger.warn("Warning: recompute navmesh first") + elif key == pressed.O: + if shift_pressed: + # move non-proxy objects in/out of visible space + self.original_objs_visible = not self.original_objs_visible + print(f"self.original_objs_visible = {self.original_objs_visible}") + if not self.original_objs_visible: + for 
_obj_handle, obj in ( + self.sim.get_rigid_object_manager() + .get_objects_by_handle_substring() + .items() + ): + if self.proxy_obj_postfix not in obj.creation_attributes.handle: + obj.motion_type = habitat_sim.physics.MotionType.KINEMATIC + obj.translation = obj.translation + mn.Vector3(200, 0, 0) + obj.motion_type = habitat_sim.physics.MotionType.STATIC + else: + for _obj_handle, obj in ( + self.sim.get_rigid_object_manager() + .get_objects_by_handle_substring() + .items() + ): + if self.proxy_obj_postfix not in obj.creation_attributes.handle: + obj.motion_type = habitat_sim.physics.MotionType.KINEMATIC + obj.translation = obj.translation - mn.Vector3(200, 0, 0) + obj.motion_type = habitat_sim.physics.MotionType.STATIC + else: + if self.col_proxy_objs is None: + self.col_proxy_objs = [] + for _obj_handle, obj in ( + self.sim.get_rigid_object_manager() + .get_objects_by_handle_substring() + .items() + ): + if self.proxy_obj_postfix not in obj.creation_attributes.handle: + # add a new proxy object + self.col_proxy_objs.append(self.add_col_proxy_object(obj)) + else: + self.col_proxies_visible = not self.col_proxies_visible + print(f"self.col_proxies_visible = {self.col_proxies_visible}") + + # make the proxies visible or not by moving them + if not self.col_proxies_visible: + for obj in self.col_proxy_objs: + obj.translation = obj.translation + mn.Vector3(200, 0, 0) + else: + for obj in self.col_proxy_objs: + obj.translation = obj.translation - mn.Vector3(200, 0, 0) + + elif key == pressed.R: + # Reload current scene + self.cycleScene(False, shift_pressed=shift_pressed) + + elif key == pressed.T: + if shift_pressed: + # open all the AO default links + all_objects = sutils.get_all_objects(self.sim) + aos = [ + obj + for obj in all_objects + if isinstance(obj, habitat_sim.physics.ManagedArticulatedObject) + ] + for ao in aos: + default_link = sutils.get_ao_default_link(ao, True) + sutils.open_link(ao, default_link) + # compute and set the receptacle filters + for 
rix, rec in enumerate(self.receptacles): + rec_accessible, filter_type = self.check_rec_accessibility(rec) + self.set_filter_status_for_rec(rec, filter_type) + print(f"-- progress = {rix}/{len(self.receptacles)} --") + else: + if self.selected_rec is not None: + rec_accessible, filter_type = self.check_rec_accessibility( + self.selected_rec, clean_up=False + ) + self.set_filter_status_for_rec(self.selected_rec, filter_type) + else: + print("No selected receptacle, can't test accessibility.") + # self.modify_param_from_term() + + # load URDF + # fixed_base = alt_pressed + # urdf_file_path = "" + # if shift_pressed and self.cached_urdf: + # urdf_file_path = self.cached_urdf + # else: + # urdf_file_path = input("Load URDF: provide a URDF filepath:").strip() + # if not urdf_file_path: + # logger.warn("Load URDF: no input provided. Aborting.") + # elif not urdf_file_path.endswith((".URDF", ".urdf")): + # logger.warn("Load URDF: input is not a URDF. Aborting.") + # elif os.path.exists(urdf_file_path): + # self.cached_urdf = urdf_file_path + # aom = self.sim.get_articulated_object_manager() + # ao = aom.add_articulated_object_from_urdf( + # urdf_file_path, + # fixed_base, + # 1.0, + # 1.0, + # True, + # maintain_link_order=False, + # intertia_from_urdf=False, + # ) + # ao.translation = ( + # self.default_agent.scene_node.transformation.transform_point( + # [0.0, 1.0, -1.5] + # ) + # ) + # # check removal and auto-creation + # joint_motor_settings = habitat_sim.physics.JointMotorSettings( + # position_target=0.0, + # position_gain=1.0, + # velocity_target=0.0, + # velocity_gain=1.0, + # max_impulse=1000.0, + # ) + # existing_motor_ids = ao.existing_joint_motor_ids + # for motor_id in existing_motor_ids: + # ao.remove_joint_motor(motor_id) + # ao.create_all_motors(joint_motor_settings) + # else: + # logger.warn("Load URDF: input file not found. 
Aborting.") + + elif key == pressed.U: + rom = self.sim.get_rigid_object_manager() + # add objects to the selected receptacle or remove al objects + if shift_pressed: + # remove all + print(f"Removing {len(self.clutter_object_instances)} clutter objects.") + for obj in self.clutter_object_instances: + rom.remove_object_by_handle(obj.handle) + self.clutter_object_initial_states.clear() + self.clutter_object_instances.clear() + else: + # try to sample an object from the selected object receptacles + rec_set = None + if alt_pressed: + # use all active filter recs + rec_set = [ + rec + for rec in self.receptacles + if rec.unique_name in self.rec_filter_data["active"] + ] + elif self.selected_rec is not None: + rec_set = [self.selected_rec] + elif self.selected_object is not None: + rec_set = [ + rec + for rec in self.receptacles + if self.selected_object.handle == rec.parent_object_handle + ] + if rec_set is not None: + rec_set_unique_names = [rec.unique_name for rec in rec_set] + obj_samp = ObjectSampler( + self.clutter_object_handles, + ["rec set"], + orientation_sample="up", + num_objects=(1, 10), + ) + obj_samp.receptacle_instances = self.receptacles + rec_set_obj = hab_receptacle.ReceptacleSet( + "rec set", [""], [], rec_set_unique_names, [] + ) + recep_tracker = hab_receptacle.ReceptacleTracker( + {}, + {"rec set": rec_set_obj}, + ) + new_objs = obj_samp.sample( + self.sim, recep_tracker, [], snap_down=True + ) + for obj, rec in new_objs: + self.clutter_object_instances.append(obj) + self.clutter_object_initial_states.append( + (obj.translation, obj.rotation) + ) + print(f"Sampled '{obj.handle}' in '{rec.unique_name}'") + else: + print("No object selected, cannot sample clutter.") + + elif key == pressed.V: + # load receptacles and toggle visibilty or color mode (+SHIFT) + if self.receptacles is None: + self.load_receptacles() + + if shift_pressed: + self.rec_color_mode = RecColorMode( + (self.rec_color_mode.value + 1) % len(RecColorMode) + ) + 
print(f"self.rec_color_mode = {self.rec_color_mode}") + self.display_receptacles = True + else: + self.display_receptacles = not self.display_receptacles + print(f"self.display_receptacles = {self.display_receptacles}") + # update map of moving/looking keys which are currently pressed if key in self.pressed: self.pressed[key] = True @@ -680,6 +1601,11 @@ def mouse_move_event(self, event: Application.MouseMoveEvent) -> None: mouse button to steer the agent's facing direction. When in GRAB mode, continues to update the grabber's object position with our agents position. """ + + render_camera = self.render_camera.render_camera + ray = render_camera.unproject(self.get_mouse_position(event.position)) + self.mouse_cast_results = self.sim.cast_ray(ray=ray) + button = Application.MouseMoveEvent.Buttons # if interactive mode -> LOOK MODE if event.buttons == button.LEFT and self.mouse_interaction == MouseMode.LOOK: @@ -712,6 +1638,9 @@ def mouse_press_event(self, event: Application.MouseEvent) -> None: """ button = Application.MouseEvent.Button physics_enabled = self.sim.get_physics_simulation_library() + mod = Application.InputEvent.Modifier + shift_pressed = bool(event.modifiers & mod.SHIFT) + alt_pressed = bool(event.modifiers & mod.ALT) # if interactive mode is True -> GRAB MODE if self.mouse_interaction == MouseMode.GRAB and physics_enabled: @@ -720,83 +1649,168 @@ def mouse_press_event(self, event: Application.MouseEvent) -> None: raycast_results = self.sim.cast_ray(ray=ray) if raycast_results.has_hits(): - hit_object, ao_link = -1, -1 + ao_link = -1 hit_info = raycast_results.hits[0] if hit_info.object_id > habitat_sim.stage_id: - # we hit an non-staged collision object - ro_mngr = self.sim.get_rigid_object_manager() - ao_mngr = self.sim.get_articulated_object_manager() - ao = ao_mngr.get_object_by_id(hit_info.object_id) - ro = ro_mngr.get_object_by_id(hit_info.object_id) - - if ro: - # if grabbed an object - hit_object = hit_info.object_id - object_pivot = 
ro.transformation.inverted().transform_point( - hit_info.point + obj = sutils.get_obj_from_id( + self.sim, hit_info.object_id, self.ao_link_map + ) + + if obj is None: + raise AssertionError( + "hit object_id is not valid. Did not find object or link." ) - object_frame = ro.rotation.inverted() - elif ao: - # if grabbed the base link - hit_object = hit_info.object_id - object_pivot = ao.transformation.inverted().transform_point( + + if obj.object_id == hit_info.object_id: + # ro or ao base + object_pivot = obj.transformation.inverted().transform_point( hit_info.point ) - object_frame = ao.rotation.inverted() - else: - for ao_handle in ao_mngr.get_objects_by_handle_substring(): - ao = ao_mngr.get_object_by_handle(ao_handle) - link_to_obj_ids = ao.link_object_ids - - if hit_info.object_id in link_to_obj_ids: - # if we got a link - ao_link = link_to_obj_ids[hit_info.object_id] - object_pivot = ( - ao.get_link_scene_node(ao_link) - .transformation.inverted() - .transform_point(hit_info.point) - ) - object_frame = ao.get_link_scene_node( - ao_link - ).rotation.inverted() - hit_object = ao.object_id - break - # done checking for AO - - if hit_object >= 0: - node = self.default_agent.scene_node - constraint_settings = physics.RigidConstraintSettings() - - constraint_settings.object_id_a = hit_object - constraint_settings.link_id_a = ao_link - constraint_settings.pivot_a = object_pivot - constraint_settings.frame_a = ( - object_frame.to_matrix() @ node.rotation.to_matrix() + object_frame = obj.rotation.inverted() + elif isinstance(obj, physics.ManagedArticulatedObject): + # link + ao_link = obj.link_object_ids[hit_info.object_id] + object_pivot = ( + obj.get_link_scene_node(ao_link) + .transformation.inverted() + .transform_point(hit_info.point) ) - constraint_settings.frame_b = node.rotation.to_matrix() - constraint_settings.pivot_b = hit_info.point + object_frame = obj.get_link_scene_node( + ao_link + ).rotation.inverted() + + print(f"Grabbed object {obj.handle}") + if 
ao_link >= 0: + print(f" link id {ao_link}") + + # setup the grabbing constraints + node = self.default_agent.scene_node + constraint_settings = physics.RigidConstraintSettings() + + constraint_settings.object_id_a = obj.object_id + constraint_settings.link_id_a = ao_link + constraint_settings.pivot_a = object_pivot + constraint_settings.frame_a = ( + object_frame.to_matrix() @ node.rotation.to_matrix() + ) + constraint_settings.frame_b = node.rotation.to_matrix() + constraint_settings.pivot_b = hit_info.point - # by default use a point 2 point constraint - if event.button == button.RIGHT: - constraint_settings.constraint_type = ( - physics.RigidConstraintType.Fixed - ) + # by default use a point 2 point constraint + if event.button == button.RIGHT: + constraint_settings.constraint_type = ( + physics.RigidConstraintType.Fixed + ) - grip_depth = ( - hit_info.point - render_camera.node.absolute_translation - ).length() + grip_depth = ( + hit_info.point - render_camera.node.absolute_translation + ).length() + + self.mouse_grabber = MouseGrabber( + constraint_settings, + grip_depth, + self.sim, + ) - self.mouse_grabber = MouseGrabber( - constraint_settings, - grip_depth, - self.sim, - ) - else: - logger.warn("Oops, couldn't find the hit object. 
That's odd.") # end if didn't hit the scene # end has raycast hit # end has physics enabled + elif ( + self.mouse_interaction == MouseMode.LOOK + and physics_enabled + and self.mouse_cast_results is not None + and self.mouse_cast_results.has_hits() + and event.button == button.RIGHT + ): + self.selected_object = None + self.selected_rec = None + hit_id = self.mouse_cast_results.hits[0].object_id + # right click in look mode to print object information + if hit_id == habitat_sim.stage_id: + print("This is the stage.") + else: + obj = sutils.get_obj_from_id(self.sim, hit_id) + link_id = None + if obj.object_id != hit_id: + # this is a link + link_id = obj.link_object_ids[hit_id] + self.selected_object = obj + print(f"Object: {obj.handle}") + if self.receptacles is not None: + for rec in self.receptacles: + if rec.parent_object_handle == obj.handle: + print(f" - Receptacle: {rec.name}") + if shift_pressed: + if obj.handle not in self.poh_to_rec: + new_rec = hab_receptacle.AnyObjectReceptacle( + obj.handle + "_aor", + parent_object_handle=obj.handle, + parent_link=link_id, + ) + self.receptacles.append(new_rec) + self.poh_to_rec[obj.handle] = [new_rec] + self.rec_to_poth[new_rec] = obj.creation_attributes.handle + self.selected_rec = self.get_closest_receptacle( + self.mouse_cast_results.hits[0].point + ) + if self.selected_rec is not None: + print(f"Selected Receptacle: {self.selected_rec.name}") + elif alt_pressed: + filtered_rec = self.get_closest_receptacle( + self.mouse_cast_results.hits[0].point + ) + if filtered_rec is not None: + filtered_rec_name = filtered_rec.unique_name + print(f"Modified Receptacle Filter State: {filtered_rec_name}") + if ( + filtered_rec_name + in self.rec_filter_data["manually_filtered"] + ): + print(" remove from manual filter") + # this was manually filtered, remove it and try to make active + self.rec_filter_data["manually_filtered"].remove( + filtered_rec_name + ) + add_to_active = True + for other_out_set in [ + "access_filtered", + 
"stability_filtered", + "height_filtered", + ]: + if ( + filtered_rec_name + in self.rec_filter_data[other_out_set] + ): + print(f" is in {other_out_set}") + add_to_active = False + break + if add_to_active: + print(" is active") + self.rec_filter_data["active"].append(filtered_rec_name) + elif filtered_rec_name in self.rec_filter_data["active"]: + print(" remove from active, add manual filter") + # this was active, remove it and mark manually filtered + self.rec_filter_data["active"].remove(filtered_rec_name) + self.rec_filter_data["manually_filtered"].append( + filtered_rec_name + ) + else: + print(" add to manual filter, but has other filter") + # this is already filtered, but add it to manual filters + self.rec_filter_data["manually_filtered"].append( + filtered_rec_name + ) + elif isinstance(obj, habitat_sim.physics.ManagedArticulatedObject): + # get the default link + default_link = sutils.get_ao_default_link(obj, True) + if default_link is None: + print("Selected AO has no default link.") + else: + if sutils.link_is_open(obj, default_link, 0.05): + sutils.close_link(obj, default_link) + else: + sutils.open_link(obj, default_link) self.previous_mouse_point = self.get_mouse_position(event.position) self.redraw() @@ -915,6 +1929,11 @@ def navmesh_config_and_recompute(self) -> None: self.sim.pathfinder, self.navmesh_settings, ) + self.largest_island_ix = get_largest_island_index( + pathfinder=self.sim.pathfinder, + sim=self.sim, + allow_outdoor=False, + ) def exit_event(self, event: Application.ExitEvent): """ @@ -947,9 +1966,11 @@ def draw_text(self, sensor_spec): self.window_text.render( f""" {self.fps} FPS +Scene ID : {os.path.split(self.cfg.sim_cfg.scene_id)[1].split('.scene_instance')[0]} Sensor Type: {sensor_type_string} Sensor Subtype: {sensor_subtype_string} Mouse Interaction Mode: {mouse_mode_string} +Unstable Objects: {self.num_unstable_objects} of {len(self.clutter_object_instances)} """ ) self.shader.draw(self.window_text.mesh) @@ -973,7 +1994,10 @@ 
def print_help_text(self) -> None: Click and drag to rotate the agent and look up/down. WHEEL: Modify orthographic camera zoom/perspective camera FOV (+SHIFT for fine grained control) - + RIGHT: + Click an object to select the object. Prints object name and attached receptacle names. Selected object displays sample points when cpo is initialized. + (+SHIFT) select a receptacle. + (+ALT) add or remove a receptacle from the "manual filter set". In GRAB mode (with 'enable-physics'): LEFT: Click and drag to pickup and move an object with a point-to-point constraint (e.g. ball joint). @@ -1006,15 +2030,26 @@ def print_help_text(self) -> None: ',': Render a Bullet collision shape debug wireframe overlay (white=active, green=sleeping, blue=wants sleeping, red=can't sleep). 'c': Run a discrete collision detection pass and render a debug wireframe overlay showing active contact points and normals (yellow=fixed length normals, red=collision distances). (+SHIFT) Toggle the contact point debug render overlay on/off. - 'j' Toggle Semantic visualization bounds (currently only Semantic Region annotations) + 'j' Clear the joint states of all articulated objects. + 'k' Toggle Semantic visualization bounds (currently only Semantic Region annotations) Object Interactions: SPACE: Toggle physics simulation on/off. '.': Take a single simulation step if not simulating continuously. - 'v': (physics) Invert gravity. - 't': Load URDF from filepath - (+SHIFT) quick re-load the previously specified URDF - (+ALT) load the URDF with fixed base + + Receptacle Evaluation Tool UI: + 'v': Load all Receptacles for the scene and toggle Receptacle visibility. + (+SHIFT) Iterate through receptacle color modes. + 'f': Toggle Receptacle view filtering. When on, only non-filtered Receptacles are visible. + (+SHIFT) Export current filter metadata to file. + (+ALT) Import filter metadata from file. + 'o': Toggle display of collision proxy shapes for the scene. 
def init_cpo_for_scene(sim_settings, mm: habitat_sim.metadata.MetadataMediator):
    """
    Initialize and run the CPO (csa.CollisionProxyOptimizer) for all objects in the scene.

    The optimizer is stored in the module-level global ``_cpo``. Every object of
    the configured scene which has receptacles defined is then processed in turn:
    ground truth setup, receptacle stability and receptacle access metrics.

    :param sim_settings: The viewer settings. The "scene_dataset_config_file" and "scene" entries are read here.
    :param mm: The MetadataMediator passed to the optimizer and used to query object templates.
    """
    global _cpo

    _cpo = csa.CollisionProxyOptimizer(sim_settings, None, mm)

    # get object handles from a specific scene
    objects_in_scene = csa.get_objects_in_scene(
        dataset_path=sim_settings["scene_dataset_config_file"],
        scene_handle=sim_settings["scene"],
        mm=_cpo.mm,
    )
    # keep only the subset with receptacles defined
    objects_with_receptacles = [
        obj_handle
        for obj_handle in objects_in_scene
        if csa.object_has_receptacles(obj_handle, mm.object_template_manager)
    ]

    # NOTE: the objects are processed serially and this call blocks until all are done.
    # TODO: the threaded variant was disabled; re-introduce it here if viewer startup
    # should no longer wait for the CPO.
    for obj_handle in objects_with_receptacles:
        _cpo.setup_obj_gt(obj_handle)
        _cpo.compute_receptacle_stability(obj_handle, use_gt=True)
        _cpo.compute_receptacle_stability(obj_handle)
        _cpo.compute_receptacle_access_metrics(obj_handle, use_gt=True)
        _cpo.compute_receptacle_access_metrics(obj_handle, use_gt=False)
default="./rec_filter_data.json", + type=str, + help='Receptacle filtering metadata (default: "./rec_filter_data.json")', + ) + parser.add_argument( + "--init-cpo", + action="store_true", + help="Initialize and run the CPO for the current scene.", + ) parser.add_argument( "--disable-physics", action="store_true", @@ -1184,15 +2269,21 @@ def next_frame() -> None: nargs="*", help="Composite files that the batch renderer will use in-place of simulation assets to improve memory usage and performance. If none is specified, the original scene files will be loaded from disk.", ) + parser.add_argument( + "--no-navmesh", + default=False, + action="store_true", + help="Don't build navmesh.", + ) parser.add_argument( "--width", - default=800, + default=1080, type=int, help="Horizontal resolution of the window.", ) parser.add_argument( "--height", - default=600, + default=720, type=int, help="Vertical resolution of the window.", ) @@ -1217,8 +2308,20 @@ def next_frame() -> None: sim_settings["composite_files"] = args.composite_files sim_settings["window_width"] = args.width sim_settings["window_height"] = args.height - sim_settings["default_agent_navmesh"] = False + sim_settings["rec_filter_file"] = args.rec_filter_file sim_settings["enable_hbao"] = args.hbao + sim_settings["viewer_ignore_navmesh"] = args.no_navmesh + + # don't need auto-navmesh + sim_settings["default_agent_navmesh"] = False + + mm = habitat_sim.metadata.MetadataMediator() + mm.active_dataset = sim_settings["scene_dataset_config_file"] + + # initialize the CPO. 
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import csv
import os
import subprocess
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence


def file_endswith(filepath: str, end_str: str) -> bool:
    """
    Return whether or not the filepath ends with a string.
    """
    return filepath.endswith(end_str)


def find_files(
    root_dir: str, discriminator: Callable[[str, str], bool], disc_str: str
) -> List[str]:
    """
    Recursively find all filepaths under a root directory satisfying a particular constraint as defined by a discriminator function.

    :param root_dir: The root directory for the recursive search.
    :param discriminator: The discriminator function which takes a filepath and discriminator string and returns a bool.
    :param disc_str: The discriminator string passed as second argument to the discriminator function.

    :return: The list of all filepaths found satisfying the discriminator. Paths are built by joining onto 'root_dir'. Empty if 'root_dir' does not exist.
    """
    filepaths: List[str] = []

    if not os.path.exists(root_dir):
        # report the missing directory itself (not the 'dir' builtin)
        print(" Directory does not exist: " + str(root_dir))
        return filepaths

    for entry in os.listdir(root_dir):
        entry_path = os.path.join(root_dir, entry)
        if os.path.isdir(entry_path):
            filepaths.extend(find_files(entry_path, discriminator, disc_str))
        # apply a user-provided discriminator function to cull filepaths
        elif discriminator(entry_path, disc_str):
            filepaths.append(entry_path)
    return filepaths


def load_scene_map_file(filepath: str) -> Dict[str, List[str]]:
    """
    Loads a csv file containing a mapping of scenes to objects. Returns that mapping as a Dict.
    NOTE: Assumes column 0 is object id and column 1 is scene id. The first row is treated as column labels and skipped.

    :param filepath: Path to an existing '.csv' file.

    :return: A mapping from scene id to the list of object ids found for that scene, in file order.
    """
    assert os.path.exists(filepath), f"'{filepath}' does not exist."
    assert filepath.endswith(".csv"), f"'{filepath}' is not a .csv file."

    scene_object_map: Dict[str, List[str]] = defaultdict(list)
    with open(filepath, newline="") as f:
        reader = csv.reader(f)
        # first row holds the column labels
        next(reader, None)
        for row in reader:
            if len(row) < 2:
                # blank or truncated line: nothing to map
                continue
            object_hash, scene_id = row[0], row[1]
            scene_object_map[scene_id].append(object_hash)

    return scene_object_map


def run_armature_urdf_conversion(blend_file: str, export_path: str, script_path: str):
    """
    Run the Blender converter script on a single .blend file in two passes: first export the meshes, then export the URDF and ao_config.

    Blender is invoked without a shell and with an argument list, so paths containing spaces or shell metacharacters are passed through verbatim.
    A failing pass is reported on stdout and does not raise.

    :param blend_file: Path to the .blend file to convert. Must exist.
    :param export_path: Output directory. Created if missing.
    :param script_path: Path to the converter script which Blender runs via '--python'.
    """
    assert os.path.exists(blend_file), f"'{blend_file}' does not exist."
    os.makedirs(export_path, exist_ok=True)
    base_command = [
        "blender",
        blend_file,
        "--background",
        "--python",
        script_path,
        "--",
        "--export-path",
        export_path,
    ]
    export_passes = (
        # first export the meshes
        ["--export-meshes", "--fix-materials"],
        # then export the URDF
        [
            "--export-urdf",
            "--export-ao-config",
            "--round-collision-scales",
            "--fix-collision-scales",
        ],
    )
    for pass_args in export_passes:
        try:
            result = subprocess.run(base_command + pass_args, check=False)
        except OSError as e:
            # e.g. the blender executable is not on the PATH; a second pass cannot succeed either
            print(f"Failed to launch blender for '{blend_file}': {e}")
            return
        if result.returncode != 0:
            print(
                f"blender exited with code {result.returncode} for '{blend_file}'."
            )


def get_dirs_in_dir(dirpath: str) -> List[str]:
    """
    Get the directory names inside a directory path.
    Anything from the first '.glb' onward is cut from each returned name.
    """
    return [
        entry.split(".glb")[0]
        for entry in os.listdir(dirpath)
        if os.path.isdir(os.path.join(dirpath, entry))
    ]


def get_dirs_in_dir_complete(dirpath: str) -> List[str]:
    """
    Get the directory names inside a directory path for directories which contain:
     - at least one .urdf file
     - at least one .ao_config.json file
     - more than 2 .glb files
    NOTE(review): this was previously documented as "at least 2 .glb files (for articulation)" while the check has always required more than 2. The check is left unchanged; confirm the intended threshold.
    TODO: check the urdf contents for all .glbs
    """
    relevant_entries = []
    for entry in os.listdir(dirpath):
        entry_path = os.path.join(dirpath, entry)
        if not os.path.isdir(entry_path):
            continue
        contents = os.listdir(entry_path)
        has_urdf = any(name.endswith(".urdf") for name in contents)
        has_config = any(name.endswith(".ao_config.json") for name in contents)
        num_glbs = sum(name.endswith(".glb") for name in contents)
        if has_urdf and has_config and num_glbs > 2:
            relevant_entries.append(entry.split(".glb")[0])

    return relevant_entries


# -----------------------------------------
# Batches blender converter calls over a directory of blend files
# e.g. python tools/batched_armature_to_urdf.py --root-dir ~/Downloads/OneDrive_1_9-27-2023/ --out-dir tools/armature_out_test/ --converter-script-path tools/blender_armature_to_urdf.py
# e.g. add " --skip-strings wardrobe" to skip all objects with "wardrobe" in the filepath
# -----------------------------------------
def main(argv: Optional[Sequence[str]] = None) -> None:
    """
    Command line entry point: collect .blend files under the root directory, apply the requested filters and run the converter on each remaining file.

    :param argv: Optional argument list to parse instead of the process command line (argparse falls back to sys.argv when None).
    """
    parser = argparse.ArgumentParser(
        description="Run Blender Armature to URDF converter for all .blend files in a directory."
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        required=True,
        help="Path to a directory containing .blend files for conversion.",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        required=True,
        help="Path to a directory for exporting URDF and assets.",
    )
    parser.add_argument(
        "--converter-script-path",
        type=str,
        help="Path to blender_armature_to_urdf.py.",
        default="tools/blender_armature_to_urdf.py",
    )
    parser.add_argument(
        "--skip-strings",
        nargs="+",
        type=str,
        help="Substrings which indicate a path which should be skippped. E.g. an object hash '6f57e5076e491f54896631bfe4e9cfcaa08899e2' to skip that object's blend file.",
        default=None,
    )
    parser.add_argument(
        "--scene-map-file",
        type=str,
        default=None,
        help="Path to a csv file with scene to object mappings. Used in conjuction with 'scenes' to limit conversion to a small batch.",
    )
    parser.add_argument(
        "--scenes",
        nargs="+",
        type=str,
        help="Substrings which indicate scenes which should be converted. Must be provided with a scene map file. When provided, only these scenes are converted.",
        default=None,
    )
    parser.add_argument(
        "--no-replace",
        default=False,
        action="store_true",
        help="If specified, cull candidate .blend files if there already exists a matching output directory for the asset.",
    )
    parser.add_argument(
        "--assets",
        nargs="+",
        type=str,
        help="Asset name substrings which indicate the subset of assets which should be converted. When provided, only these assets are converted.",
        default=None,
    )

    args = parser.parse_args(argv)
    root_dir = args.root_dir
    if not os.path.isdir(root_dir):
        parser.error(f"root directory '{root_dir}' must exist.")
    if not os.path.exists(args.converter_script_path):
        parser.error(
            f"provided script path '{args.converter_script_path}' does not exist."
        )

    # get blend files; every filter below preserves the discovery order
    blend_paths = find_files(root_dir, file_endswith, ".blend")
    if args.skip_strings is not None:
        blend_paths = [
            path
            for path in blend_paths
            if not any(skip_str in path for skip_str in args.skip_strings)
        ]

    if args.no_replace:
        # a missing output directory simply means nothing was converted yet
        out_dir_dirs = (
            get_dirs_in_dir_complete(args.out_dir)
            if os.path.isdir(args.out_dir)
            else []
        )
        completed = set(out_dir_dirs)
        remaining_blend_paths = [
            blend
            for blend in blend_paths
            if os.path.basename(blend).split(".")[0] not in completed
        ]
        print(f"original blends = {len(blend_paths)}")
        print(f"existing dirs = {len(out_dir_dirs)}")
        print(f"remaining_blend_paths = {len(remaining_blend_paths)}")
        remaining_hashes = [
            os.path.basename(blend_path) for blend_path in remaining_blend_paths
        ]
        print(f"remaining_hashes = {remaining_hashes}")
        blend_paths = remaining_blend_paths

    if args.scene_map_file is not None and args.scenes is not None:
        # load the scene map file and limit the object set by scenes
        scene_object_map = load_scene_map_file(args.scene_map_file)
        object_ids = [
            object_id
            for scene in args.scenes
            for object_id in scene_object_map.get(scene, [])
        ]
        blend_paths = [
            path
            for path in blend_paths
            if any(object_id in path for object_id in object_ids)
        ]

    if args.assets is not None:
        # each path is kept at most once even if it matches several asset names
        blend_paths = [
            path
            for path in blend_paths
            if any(name_str in path for name_str in args.assets)
        ]

    for blend_path in blend_paths:
        run_armature_urdf_conversion(
            blend_file=blend_path,
            export_path=args.out_dir,
            script_path=args.converter_script_path,
        )


if __name__ == "__main__":
    main()
# Colors from https://colorbrewer2.org/
colors = [
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
]
# normalize 8-bit channels to floats in [0, 1]
colors = [(c[0] / 255, c[1] / 255, c[2] / 255) for c in colors]
next_color = 0

# state vars
armature = None
counter = 0
links = []
joints = []
bones_to_meshes: Dict[Any, List[Any]] = defaultdict(
    list
)  # maps bones to lists of mesh objects

# constants
LINK_NAME_FORMAT = "{bone_name}"
JOINT_NAME_FORMAT = "{bone_name}"
ORIGIN_NODE_FLOAT_PRECISION = 6
ORIGIN_NODE_FORMAT = "{{:.{0}f}} {{:.{0}f}} {{:.{0}f}}".format(
    ORIGIN_NODE_FLOAT_PRECISION
)
# NOTE: the XML literals below must be well-formed elements: ET.fromstring raises ParseError on an empty string.
# Each call parses a fresh element so callers may mutate and append the result.
ZERO_ORIGIN_NODE = lambda: ET.fromstring('<origin xyz="0 0 0" rpy="0 0 0"/>')
# six placeholders, filled as (ixx, ixy, ixz, iyy, iyz, izz)
INERTIA_NODE_FMT = '<inertia ixx="{}" ixy="{}" ixz="{}" iyy="{}" iyz="{}" izz="{}"/>'
AXIS_NODE_FORMAT = lambda: ET.fromstring('<axis xyz="0 0 0"/>')
BASE_LIMIT_NODE_STR = None


def round_scales(keyword: str = "collision"):
    """
    Rounds shape scale vectors to 4 decimal places (0.1 mm if scene units are meters). E.g. to eliminate small mismatches in collision shape scale.
    Use 'keyword' to discriminate between shape types.

    :param keyword: Only scene objects whose name contains this substring are modified.
    """
    for obj in bpy.context.scene.objects:
        if keyword not in obj.name:
            continue
        for axis in range(3):  # Iterate over X, Y, Z
            obj.scale[axis] = round(obj.scale[axis], 4)


def fix_scales(keyword: str = "collision"):
    """
    Flips negative scales for shapes. (E.g. collision shapes should not have negative scaling)
    Use 'keyword' to discriminate between shape types.

    :param keyword: Only scene objects whose name contains this substring are modified.
    """
    for obj in bpy.context.scene.objects:
        if keyword not in obj.name:
            continue
        for axis in range(3):  # Iterate over X, Y, Z
            obj.scale[axis] = abs(obj.scale[axis])
def set_base_limit_str(effort, velocity):
    """
    Default effort and velocity limits for Joints.

    Stores a '<limit>' element template in the module-level BASE_LIMIT_NODE_STR.
    The template must be a well-formed element because it is later parsed with ET.fromstring, which rejects an empty string.

    :param effort: Default joint effort limit.
    :param velocity: Default joint velocity limit.
    """
    global BASE_LIMIT_NODE_STR
    BASE_LIMIT_NODE_STR = '<limit effort="{}" velocity="{}"/>'.format(
        effort, velocity
    )


def deselect_all() -> None:
    """
    Deselect all currently selected objects.
    """
    for selected in bpy.context.selected_objects:
        selected.select_set(False)


def is_mesh(obj) -> bool:
    """
    Is the object a MESH?
    """
    return obj.type == "MESH"


def is_collision_mesh(mesh_obj) -> bool:
    """
    Is the object a collision mesh, i.e. does its name contain 'collision'.
    """
    return "collision" in mesh_obj.name


def is_receptacle_mesh(mesh_obj) -> bool:
    """
    Is the object a receptacle mesh, i.e. does its name contain 'receptacle'.
    """
    return "receptacle" in mesh_obj.name


def get_mesh_heirarchy(
    mesh_obj,
    select_set: bool = True,
    include_render: bool = True,
    include_collison: bool = False,
    include_receptacle: bool = False,
) -> List[Any]:
    """
    Collect (and optionally select) MESH objects in the hierarchy below and including 'mesh_obj', filtered by the "collision" / "receptacle" name qualifiers.

    Children of type ARMATURE are not descended into. Other non-mesh nodes are traversed but never collected.

    :param mesh_obj: The Blender object at the root of the traversal.
    :param select_set: Whether or not to select the objects as well as recording them.
    :param include_render: Include render objects without qualifiers in the name.
    :param include_collison: Include objects with 'collision' in the name.
    :param include_receptacle: Include objects with 'receptacle' in the name.

    :return: The list of Blender mesh objects in depth-first order.
    """
    collected = []
    if is_mesh(mesh_obj):
        is_col_mesh = is_collision_mesh(mesh_obj)
        is_rec_mesh = is_receptacle_mesh(mesh_obj)
        is_render_mesh = not is_col_mesh and not is_rec_mesh
        wanted = (
            (is_col_mesh and include_collison)
            or (is_rec_mesh and include_receptacle)
            or (is_render_mesh and include_render)
        )
        if wanted:
            collected.append(mesh_obj)
            if select_set:
                mesh_obj.select_set(True)
    for child in mesh_obj.children:
        if child.type == "ARMATURE":
            continue
        collected.extend(
            get_mesh_heirarchy(
                child,
                select_set,
                include_render,
                include_collison,
                include_receptacle,
            )
        )
    return collected
def walk_armature(this_bone, handler, kwargs_for_handler=None):
    """
    Apply a handler function to a bone and then, depth-first, to all of its descendants.

    :param this_bone: The bone to start from. Must expose a 'children' iterable.
    :param handler: Callable invoked as handler(bone, **kwargs_for_handler) for every visited bone.
    :param kwargs_for_handler: Optional keyword arguments forwarded unchanged to every handler call.
    """
    handler_kwargs = {} if kwargs_for_handler is None else kwargs_for_handler
    handler(this_bone, **handler_kwargs)
    for child_bone in this_bone.children:
        walk_armature(child_bone, handler, handler_kwargs)


def _print_name_and_children(label, item) -> None:
    """
    Print the name of 'item' and the names of its direct children under a '<label> info' header.
    """
    print(f" {label} info")
    print(f" name: {item.name}")
    print(" children")
    for child in item.children:
        print(f" - {child.name}")


def bone_info(bone):
    """
    Print relevant bone info to console.
    """
    _print_name_and_children("bone", bone)


def node_info(node):
    """
    Print relevant info about an object node.
    """
    _print_name_and_children("node", node)


def get_origin_from_matrix(M):
    """
    Construct a URDF 'origin' element from a Matrix by separating translation from rotation and converting to Euler.

    :param M: A matrix exposing 'to_translation()' and 'to_euler()'.

    :return: An 'origin' XML element with 'rpy' and 'xyz' attributes formatted with ORIGIN_NODE_FORMAT.
    """
    translation = M.to_translation()
    euler = M.to_euler()
    origin_xml_node = ET.Element("origin")
    origin_xml_node.set("rpy", ORIGIN_NODE_FORMAT.format(euler.x, euler.y, euler.z))
    origin_xml_node.set(
        "xyz", ORIGIN_NODE_FORMAT.format(translation.x, translation.y, translation.z)
    )

    return origin_xml_node


def get_next_color() -> Tuple[float, float, float]:
    """
    Global function to get the next color in the colors list.

    Cycles through the module-level 'colors' list and advances the module-level 'next_color' cursor.

    :return: An (r, g, b) tuple, returned exactly as stored in 'colors'.
    """
    global next_color
    this_color = colors[next_color % len(colors)]
    next_color += 1
    return this_color
def add_color_material_to_visual(color, xml_visual) -> None:
    """
    Add a color material to a visual node.

    :param color: An (r, g, b) sequence. It is used both in the material name and as the first three 'rgba' components (alpha is fixed to 1.0).
    :param xml_visual: The XML element which receives the new 'material' child.
    """
    this_xml_material = ET.Element("material")
    this_xml_material.set("name", "mat_col_{}".format(color))
    this_xml_color = ET.Element("color")
    this_xml_color.set("rgba", "{:.2f} {:.2f} {:.2f} 1.0".format(*color))
    this_xml_material.append(this_xml_color)
    xml_visual.append(this_xml_material)


def _collision_shape_to_xml(col):
    """
    Build the URDF primitive element (box, cylinder or sphere) for a collision object from its name and scale.
    Returns None when the name matches none of the supported shapes.
    """
    scale = col.scale
    if "collision_box" in col.name:
        this_xml_box = ET.Element("box")
        this_xml_box.set("size", f"{scale.x} {scale.y} {scale.z}")
        return this_xml_box
    if "collision_cylinder" in col.name:
        # radius XY axis scale must match
        assert (
            abs(scale.x - scale.y) < 0.0001
        ), f"XY dimensions must match. Used as radius. node_name=='{col.name}', x={scale.x}, y={scale.y}"
        this_xml_cyl = ET.Element("cylinder")
        this_xml_cyl.set("radius", f"{scale.x/2.0}")
        # NOTE: assume Z axis is length of the cylinder
        this_xml_cyl.set("length", f"{scale.z}")
        return this_xml_cyl
    if "collision_sphere" in col.name:
        # radius XYZ axis scale must match
        assert (
            abs(scale.x - scale.y) < 0.0001 and abs(scale.x - scale.z) < 0.0001
        ), f"XYZ dimensions must match. Used as radius. node_name=='{col.name}', x={scale.x}, y={scale.y}, z={scale.z}"
        this_xml_sphere = ET.Element("sphere")
        this_xml_sphere.set("radius", f"{scale.x/2.0}")
        return this_xml_sphere
    return None


def bone_to_urdf(
    this_bone,
    link_visuals=True,
    collision_visuals=False,
    joint_visuals=False,
    receptacle_visuals=False,
):
    """
    Extract the basic properties of the bone and populate the module-level links and joints lists with the corresponding urdf nodes.

    NOTE: destructive. The origin and transform of every collision object attached to the bone are modified.

    :param this_bone: The armature bone to convert. A bone with a parent also produces a joint.
    :param link_visuals: Add a 'visual' mesh element for each mesh parented directly to the armature.
    :param collision_visuals: Emit collision primitives as colored 'visual' elements instead of 'collision' elements (debugging).
    :param joint_visuals: Add a small box 'visual' at the link origin (debugging).
    :param receptacle_visuals: Add colored 'visual' mesh elements for receptacle meshes (debugging).
    """

    print(f"link_visuals = {link_visuals}")
    print(f"collision_visuals = {collision_visuals}")
    print(f"joint_visuals = {joint_visuals}")
    print(f"receptacle_visuals = {receptacle_visuals}")

    global counter

    # Create the link xml node (and the joint to the parent, when there is one)
    if this_bone.parent:
        this_xml_link = create_bone_link(this_bone)
    else:
        this_xml_link = create_root_bone_link(this_bone)

    # NOTE: default intertia (assume overriden automatically in-engine)
    this_xml_link.append(ET.fromstring(INERTIA_NODE_FMT.format(1.0, 0, 0, 1.0, 0, 1.0)))

    # NOTE: default unit mass TODO: estimate somehow?
    # The literal must be a well-formed element: ET.fromstring raises ParseError on an empty string.
    # NOTE(review): the URDF spec nests mass/inertia inside <inertial>; here they are direct children of the link - confirm the consumer accepts this.
    this_xml_link.append(ET.fromstring('<mass value="{}"/>'.format(1.0)))

    # TODO: scrape the list of mesh filenames which would be generated by an export.
    collision_objects = []
    receptacle_objects = []
    for mesh_obj in bones_to_meshes[this_bone.name]:
        collision_objects.extend(
            get_mesh_heirarchy(
                mesh_obj,
                select_set=False,
                include_collison=True,
                include_render=False,
            )
        )
        if receptacle_visuals:
            receptacle_objects.extend(
                get_mesh_heirarchy(
                    mesh_obj,
                    select_set=False,
                    include_collison=False,
                    include_render=False,
                    include_receptacle=True,
                )
            )
        if mesh_obj.parent is not None and mesh_obj.parent.type == "ARMATURE":
            # this object is the mesh name for export, so use it
            # Create the visual node
            this_xml_visual = ET.Element("visual")
            this_xml_mesh_geom = ET.Element("geometry")
            this_xml_mesh = ET.Element("mesh")
            this_xml_mesh.set("filename", f"{mesh_obj.name}.glb")
            this_xml_mesh.set("scale", "1.0 1.0 1.0")
            this_xml_mesh_geom.append(this_xml_mesh)
            # NOTE: we can use zero because we reset the origin for the meshes before exporting them to glb
            this_xml_visual.append(ZERO_ORIGIN_NODE())
            this_xml_visual.append(this_xml_mesh_geom)
            if link_visuals:
                this_xml_link.append(this_xml_visual)

    # NOTE: visual debugging tool to add a box at the joint pivot locations
    if joint_visuals:
        this_xml_visual = ET.Element("visual")
        this_xml_test_geom = ET.Element("geometry")
        this_xml_box = ET.Element("box")
        this_xml_box.set("size", "0.1 0.1 0.1")
        this_xml_test_geom.append(this_xml_box)
        this_xml_visual.append(ZERO_ORIGIN_NODE())
        this_xml_visual.append(this_xml_test_geom)
        this_xml_link.append(this_xml_visual)

    # NOTE: color each link's collision shapes for debugging
    this_color = get_next_color()

    supported_collision_shapes = [
        "collision_box",
        "collision_cylinder",
        "collision_sphere",
    ]

    for col in collision_objects:
        num_matching_shapes = sum(
            1 for col_name in supported_collision_shapes if col_name in col.name
        )
        assert (
            num_matching_shapes == 1
        ), f"Only supporting exactly one of the following collision shapes currently: {supported_collision_shapes}. Shape name '{col.name}' unsupported."

        # NOTE: the order of these origin/transform operations matters, do not reorder
        set_obj_origin_to_center(col)
        clear_obj_transform(
            col, apply=True, include_scale_apply=False, include_rot_apply=False
        )
        set_obj_origin_to_xyz(col, col.parent.matrix_world.translation)
        clear_obj_transform(col)
        set_obj_origin_to_center(col)

        # Create the collision node (or a colored visual node when debugging collision shapes)
        if collision_visuals:
            this_xml_collision = ET.Element("visual")
            add_color_material_to_visual(this_color, this_xml_collision)
        else:
            this_xml_collision = ET.Element("collision")
        this_xml_col_geom = ET.Element("geometry")
        this_xml_col_geom.append(_collision_shape_to_xml(col))

        # first get the rotation
        xml_origin = get_origin_from_matrix(col.matrix_local)
        # then get local translation
        col_link_position = col.location
        xml_origin.set(
            "xyz",
            ORIGIN_NODE_FORMAT.format(
                col_link_position.x, col_link_position.y, col_link_position.z
            ),
        )
        this_xml_collision.append(xml_origin)
        this_xml_collision.append(this_xml_col_geom)
        this_xml_link.append(this_xml_collision)

    if receptacle_visuals:
        for rec_mesh in receptacle_objects:
            # NOTE: each receptacle gets its own debug color
            this_color = get_next_color()
            this_xml_visual = ET.Element("visual")
            this_xml_geom = ET.Element("geometry")
            this_xml_mesh = ET.Element("mesh")
            # NOTE(review): the filename is derived from the parent object's name, so two receptacles under one parent reference the same file - confirm against the exported receptacle .glb names
            rec_filename = rec_mesh.parent.name + "_receptacle.glb"
            this_xml_mesh.set("filename", f"{rec_filename}")
            this_xml_mesh.set("scale", "1.0 1.0 1.0")
            this_xml_geom.append(this_xml_mesh)
            # NOTE: we can use zero because we reset the origin for the meshes before exporting them to glb
            this_xml_visual.append(ZERO_ORIGIN_NODE())
            this_xml_visual.append(this_xml_geom)
            add_color_material_to_visual(this_color, this_xml_visual)
            this_xml_link.append(this_xml_visual)

    counter += 1
def create_root_bone_link(this_bone):
    """
    Build the 'link' element for a bone without a parent (i.e. the root node) and register it in the module-level links list.

    :param this_bone: The root bone.

    :return: The new 'link' XML element, named after the bone.
    """
    link_name = this_bone.name
    root_link = ET.Element("link")
    root_link.set("name", link_name)
    links.append(root_link)

    # NOTE(review): writes the bone's own name back to it; looks like a no-op - confirm before removing
    this_bone.name = link_name
    return root_link


def get_origin_from_bone(bone):
    """
    Build the 'origin' element of a joint from a bone.

    The translation is the difference between the 'matrix_local' translations of the bone and of its parent; the rotation is always "0 0 0".

    :param bone: A bone with a parent.

    :return: The 'origin' XML element with 'rpy' and 'xyz' attributes.
    """
    bone_position = bone.matrix_local.to_translation()
    parent_position = bone.parent.matrix_local.to_translation()
    offset = bone_position - parent_position

    origin = ET.Element("origin")
    origin.set("rpy", "0 0 0")
    origin.set("xyz", ORIGIN_NODE_FORMAT.format(offset.x, offset.y, offset.z))

    return origin
def create_bone_link(this_bone):
    """
    Construct Link and Joint elements from a bone.

    The new 'link' element is appended to the module-level links list and the new 'joint' element (parent, child, axis, limit, origin) to the module-level joints list.
    The joint type and limits come from get_anim_limits_info: a 4-element lower limit is treated as a quaternion (revolute), a 3-element lower limit as a translation (prismatic).

    :param this_bone: A bone with a parent; 'this_bone.parent.name' is used as the joint's parent link.

    :return: The new 'link' XML element.
    """
    global counter

    # construct limits and joint type from animation frames
    bone_limits = get_anim_limits_info(this_bone)

    # Get bone properties
    parent_bone = this_bone.parent
    base_joint_name = JOINT_NAME_FORMAT.format(
        counter=counter, bone_name=this_bone.name
    )

    # ------------- Create joint--------------

    joint = ET.Element("joint")
    joint.set("name", base_joint_name)

    # create origin node
    origin_xml_node = get_origin_from_bone(this_bone)

    # create parent node
    parent_xml_node = ET.Element("parent")
    parent_xml_node.set("link", parent_bone.name)

    xml_link = ET.Element("link")
    xml_link_name = this_bone.name
    xml_link.set("name", xml_link_name)
    links.append(xml_link)

    # create child node
    child_xml_node = ET.Element("child")
    child_xml_node.set("link", xml_link_name)

    joint.append(parent_xml_node)
    joint.append(child_xml_node)

    # create limits node
    # NOTE: BASE_LIMIT_NODE_STR must have been set (see set_base_limit_str) before this is called
    limit_node = ET.fromstring(BASE_LIMIT_NODE_STR)

    local_axis = Vector()

    # Revolute
    # NOTE(review): if the lower limit has neither 4 nor 3 elements (e.g. the empty limits returned for bones with "root" in the name) the joint gets no 'type' and 'local_axis' stays default-constructed - confirm this cannot happen for bones with a parent
    if len(bone_limits["lower_limit"]) == 4:
        joint.set("type", "revolute")
        begin = Quaternion(bone_limits["lower_limit"])
        end = Quaternion(bone_limits["upper_limit"])
        # default-constructed quaternion, presumably the identity rotation - confirm against mathutils
        rest = Quaternion()
        # rotation from the lower limit to the upper limit gives the joint axis and total range
        diff = begin.rotation_difference(end)
        local_axis, angle = diff.to_axis_angle()
        # rotation from the lower limit to 'rest' positions the range around the rest pose
        rest_diff = begin.rotation_difference(rest)
        # NOTE(review): rest_axis is unused; only the angle is read, so this presumably assumes it matches local_axis
        rest_axis, rest_angle = rest_diff.to_axis_angle()
        limit_node.set("lower", f"{-rest_angle}")
        limit_node.set("upper", f"{angle-rest_angle}")

    # Prismatic
    if len(bone_limits["lower_limit"]) == 3:
        joint.set("type", "prismatic")
        upper_vec = Vector(bone_limits["upper_limit"])
        lower_vec = Vector(bone_limits["lower_limit"])
        # the axis points from the lower limit position to the upper limit position
        displacement = upper_vec - lower_vec
        local_axis = displacement.normalized()
        # limits are the distances of each limit position from the origin
        limit_node.set("lower", f"{-lower_vec.length}")
        limit_node.set("upper", f"{upper_vec.length}")

    # NOTE: rest pose could be applied to the bone resulting in an additional rotation stored in the matrix property
    rest_correction = this_bone.matrix
    # NOTE: Blender bones and armature are always Y-up, so we need to rotate the axis into URDF coordinate space (Z-up)
    bone_axis = this_bone.vector
    to_z_up = bone_axis.rotation_difference(Vector([0, 0, 1]))
    # apply all rotations to arrive at the URDF Joint axis
    axis = rest_correction @ (to_z_up @ local_axis)

    xml_axis = AXIS_NODE_FORMAT()
    xml_axis.set("xyz", ORIGIN_NODE_FORMAT.format(axis.x, axis.y, axis.z))

    joint.append(xml_axis)
    joint.append(limit_node)
    joint.append(origin_xml_node)
    joints.append(joint)
    ret_link = xml_link

    return ret_link
def _select_only(obj) -> None:
    """
    Make 'obj' the only selected object.
    """
    deselect_all()
    obj.select_set(True)


def set_obj_origin_to_center(obj) -> None:
    """
    Move the origin of the object to the median of its own geometry.
    """
    _select_only(obj)
    bpy.ops.object.origin_set(type="ORIGIN_GEOMETRY", center="MEDIAN")


def set_obj_origin_to_xyz(obj, xyz) -> None:
    """
    Move the origin of the object to the xyz location.
    NOTE: implemented by moving the scene cursor to 'xyz'; the cursor is left there afterwards.
    """
    bpy.context.scene.cursor.location = xyz
    _select_only(obj)
    bpy.ops.object.origin_set(type="ORIGIN_CURSOR", center="MEDIAN")


def set_obj_origin_to_bone(obj, bone):
    """
    Move the origin of the object to the translation of the bone's 'matrix_local'.
    """
    bone_position = bone.matrix_local.translation
    set_obj_origin_to_xyz(obj, bone_position)


def clear_obj_transform(
    arm, apply=False, include_scale_apply=True, include_rot_apply=True
):
    """
    Apply or clear the transform of an object (e.g. the armature) to align it with the origin.

    :param arm: The object to operate on. It becomes the only selected object.
    :param apply: If True, apply the transform to the object data. If False, only the location is cleared.
    :param include_scale_apply: When applying, also apply scale.
    :param include_rot_apply: When applying, also apply rotation.
    """
    _select_only(arm)
    if not apply:
        bpy.ops.object.location_clear(clear_delta=False)
        return
    bpy.ops.object.transform_apply(
        location=True, rotation=include_rot_apply, scale=include_scale_apply
    )
def get_anim_limits_info(bone):
    """
    Get limits info from animation action tracks.

    Every action whose name contains the bone name and one of the keys "rest_pose", "lower_limit" or "upper_limit" contributes the value of its single keyframe per fcurve.
    Quaternion tracks produce 4-element lists and mark the bone revolute; location tracks produce 3-element lists and mark it prismatic.

    :param bone: The bone to query. Bones with "root" in the name have no joint data and get empty lists.

    :return: A dict with the keys "rest_pose", "lower_limit" and "upper_limit", each mapping to a list of values (empty if no matching action exists).
    """
    bone_limits = {"rest_pose": [], "lower_limit": [], "upper_limit": []}
    if "root" in bone.name:
        # no joint data defined for the root
        return bone_limits
    is_prismatic = False
    is_revolute = False

    for ac in bpy.data.actions:
        # NOTE(review): substring match, so bone "door1" also matches an action of bone "door10" - confirm the naming convention rules this out
        if bone.name not in ac.name:
            continue
        key_match = next((key for key in bone_limits if key in ac.name), None)
        if key_match is None:
            # the action mentions the bone but is not a limit track: ignore it instead of failing with a bare IndexError
            print(
                f"Skipping action '{ac.name}' for bone '{bone.name}': name contains none of {list(bone_limits)}."
            )
            continue
        limit_list = []
        for _fkey, fcurve in ac.fcurves.items():
            assert (
                len(fcurve.keyframe_points) == 1
            ), "Expecting one keyframe per track."
            index = fcurve.array_index
            value = fcurve.keyframe_points[0].co[1]
            if "quaternion" in fcurve.data_path:
                if len(limit_list) == 0:
                    limit_list = [0, 0, 0, 0]
                is_revolute = True
            if "location" in fcurve.data_path:
                if len(limit_list) == 0:
                    limit_list = [0, 0, 0]
                is_prismatic = True
            try:
                limit_list[index] = value
            except IndexError as err:
                raise Exception(
                    f"Failed to get limits for fcurve: bone={bone.name}, curve_key={_fkey}, index={index}. Should have exactly 3 (position) or exactly 4 (quaternion) elements."
                ) from err

        bone_limits[key_match] = limit_list
    assert (
        is_prismatic or is_revolute
    ), f"Bone {bone.name} does not have animation data."
    assert not (
        is_prismatic and is_revolute
    ), f"Bone {bone.name} has both rotation and translation defined."
    return bone_limits


def get_parent_bone(obj):
    """
    Climb the node tree looking for the parent bone of an object.
    Return the parent bone or None if a parent bone does not exist.

    :param obj: An object exposing 'parent_bone' (a bone name, "" if unset) and 'parent'.
    """
    node = obj
    while node is not None:
        if node.parent_bone != "":
            return armature.data.bones[node.parent_bone]
        node = node.parent
    return None


def get_root_bone():
    """
    Find the root bone, i.e. the only bone of the module-level armature without a parent.
    Return None if every bone has a parent.
    """
    root_bone = None
    for b in armature.data.bones:
        if b.parent:
            continue
        assert root_bone is None, "More than one root bone found."
        root_bone = b
    return root_bone
def get_armature():
    """
    Search the objects for an armature object.
    Return the first object of type ARMATURE or None if there is none.
    """
    return next((obj for obj in bpy.data.objects if obj.type == "ARMATURE"), None)


def construct_root_rotation_joint(root_node_name):
    """
    Construct the root rotation joint XML Element.

    The joint is a fixed joint named "root_rotation" from the link "root" to the given link, with a -90 degree rotation about X as its origin.

    :param root_node_name: Name of the child link, i.e. the link of the root bone.

    :return: The 'joint' XML element.
    """
    xml_root_joint = ET.Element("joint")
    xml_root_joint.set("name", "root_rotation")
    xml_root_joint.set("type", "fixed")

    # construct a standard rotation matrix transform to apply to all root nodes
    M = Matrix.Rotation(math.radians(-90.0), 4, "X")
    xml_root_joint.append(get_origin_from_matrix(M))

    # create parent node
    parent_xml_node = ET.Element("parent")
    parent_xml_node.set("link", "root")

    # create child node
    child_xml_node = ET.Element("child")
    child_xml_node.set("link", root_node_name)

    xml_root_joint.append(parent_xml_node)
    xml_root_joint.append(child_xml_node)
    return xml_root_joint


def export(
    dirpath,
    settings,
    export_urdf: bool = True,
    export_meshes: bool = True,
    export_ao_config: bool = True,
    fix_materials: bool = True,
    **kwargs,
):
    """
    Run the Armature to URDF converter and export the .urdf file.
    Recursively travserses the armature bone tree and constructs Links and Joints.
    Note: This process is destructive and requires undo or revert in the editor after use.

    :param dirpath: Directory in which the per-object output directory (named after the root EMPTY node) is created.
    :param settings: Dict of options. Recognized keys: "armature", "link_name_format", "joint_name_format", "round_collision_scales", "fix_collision_scales", "def_limit_effort", "def_limit_vel".
    :param export_urdf: Write the .urdf file.
    :param export_meshes: Export one .glb per armature-parented mesh and one per receptacle mesh.
    :param export_ao_config: Write the .ao_config.json file.
    :param fix_materials: Set the IOR of Principled BSDF nodes from 1.0 to 0.0 to avoid a glTF export error.
    :param kwargs: Forwarded to bone_to_urdf for every bone (visual debugging flags).

    :return: export directory or URDF filepath
    """

    output_path = dirpath

    global LINK_NAME_FORMAT, JOINT_NAME_FORMAT, armature, root_bone, links, joints, counter, next_color
    # reset all module-level state so repeated calls in one session do not accumulate stale links, meshes or colors
    counter = 0
    links = []
    joints = []
    bones_to_meshes.clear()
    next_color = 0

    # fixes a gltf export error caused by 1.0 ior values
    if fix_materials:
        for material in bpy.data.materials:
            if material.node_tree is None:
                continue
            for node in material.node_tree.nodes:
                if (
                    node.type == "BSDF_PRINCIPLED"
                    and "IOR" in node.inputs
                    and node.inputs["IOR"].default_value == 1.000
                ):
                    node.inputs["IOR"].default_value = 0.000
                    print(f"Changed IOR value for material '{material.name}'")

    bpy.context.view_layer.update()

    # check poll() to avoid exception.
    if bpy.ops.object.mode_set.poll():
        bpy.ops.object.mode_set(mode="OBJECT")

    # get the armature
    armature = settings.get("armature")
    if armature is None:
        armature = bpy.data.objects["Armature"]

    # find the root bone
    root_bone = get_root_bone()

    if "link_name_format" in settings:
        LINK_NAME_FORMAT = settings["link_name_format"]

    if "joint_name_format" in settings:
        JOINT_NAME_FORMAT = settings["joint_name_format"]

    if settings.get("round_collision_scales"):
        round_scales()

    if settings.get("fix_collision_scales"):
        fix_scales()

    # set the defaults to 100 T units and 3 units/sec (meters or radians)
    effort = settings.get("def_limit_effort", 100)
    velocity = settings.get("def_limit_vel", 3)
    set_base_limit_str(effort, velocity)

    # clear the armature transform to remove unwanted transformations for later
    clear_obj_transform(armature, apply=True)

    # print all mesh object parents, reset origins for mesh export and transformation registery, collect bone to mesh map
    root_node = None
    receptacle_meshes = []
    receptacle_to_link_name = {}
    for obj in bpy.data.objects:
        if obj.type == "MESH":
            parent_bone = get_parent_bone(obj)
            if parent_bone is None:
                # fail with the offending object's name instead of an AttributeError on None further down
                raise ValueError(
                    f"Mesh '{obj.name}' is not parented (directly or indirectly) to an armature bone, cannot reset its origin for export."
                )
            set_obj_origin_to_bone(obj, parent_bone)
            print(f"MESH: {obj.name}")
            if obj.parent is not None:
                print(f" -p> {obj.parent.name}")
            if obj.parent_bone != "":
                bones_to_meshes[obj.parent_bone].append(obj)
                print(f" -pb> {obj.parent_bone}")
            if is_receptacle_mesh(obj):
                receptacle_meshes.append(obj)
                receptacle_to_link_name[obj.name] = obj.parent.name
        elif obj.type == "EMPTY":
            print(f"EMPTY: {obj.name}")
            if obj.parent is None and len(obj.children) > 0:
                print(" --IS ROOT")
                root_node = obj

    # make export directory for the object
    assert root_node is not None, "No root node, aborting."
    final_out_path = os.path.join(dirpath, f"{root_node.name}")
    os.makedirs(final_out_path, exist_ok=True)
    print(f"Output path : {final_out_path}")

    # export mesh components
    if export_meshes:
        for mesh_list in bones_to_meshes.values():
            for mesh_obj in mesh_list:
                if mesh_obj.parent is not None and mesh_obj.parent.type == "ARMATURE":
                    clear_obj_transform(mesh_obj)
                    deselect_all()
                    # selects the render meshes of the hierarchy for the selection-only export below
                    get_mesh_heirarchy(mesh_obj)
                    bpy.ops.export_scene.gltf(
                        filepath=os.path.join(final_out_path, mesh_obj.name),
                        use_selection=True,
                        export_yup=False,
                    )
        # export receptacle meshes
        for rec_mesh in receptacle_meshes:
            clear_obj_transform(rec_mesh.parent)
            deselect_all()
            rec_mesh.select_set(True)
            bpy.ops.export_scene.gltf(
                filepath=os.path.join(final_out_path, rec_mesh.name),
                use_selection=True,
                export_yup=False,
            )

    if export_urdf:
        # Recursively generate the xml elements
        walk_armature(root_bone, bone_to_urdf, kwargs_for_handler=kwargs)

        # add all the joints and links to the root
        root_xml = ET.Element("robot")  # create
        root_xml.set("name", armature.name)

        # add a coordinate change in a dummy root node
        xml_root_link = ET.Element("link")
        xml_root_link.set("name", "root")
        xml_root_joint = construct_root_rotation_joint(root_bone.name)
        root_xml.append(xml_root_link)
        root_xml.append(xml_root_joint)

        root_xml.append(ET.Comment("LINKS"))
        root_xml.extend(links)

        root_xml.append(ET.Comment("JOINTS"))
        root_xml.extend(joints)

        # dump the xml string
        ET_raw_string = ET.tostring(root_xml, encoding="unicode")
        dom = minidom.parseString(ET_raw_string)
        ET_pretty_string = dom.toprettyxml()

        output_path = os.path.join(final_out_path, f"{root_node.name}.urdf")

        print(f"URDF output path : {output_path}")
        with open(output_path, "w") as f:
            f.write(ET_pretty_string)

    if export_ao_config:
        # write the ao_config
        ao_config_contents = {
            "urdf_filepath": f"{root_node.name}.urdf",
            "user_defined": {
                # insert receptacle metadata here
            },
        }
        for rec_name, link_name in receptacle_to_link_name.items():
            rec_label = "receptacle_mesh_" + rec_name
            ao_config_contents["user_defined"][rec_label] = {
                "name": rec_name,
                "parent_object": f"{root_node.name}",
                "parent_link": link_name,
                "position": [0, 0, 0],
                "rotation": [1, 0, 0, 0],
                "scale": [1, 1, 1],
                "up": [0, 0, 1],
                "mesh_filepath": rec_name + ".glb",
            }
        ao_config_filename = os.path.join(
            final_out_path, f"{root_node.name}.ao_config.json"
        )

        print(f"ao config output path : {ao_config_filename}")
        with open(ao_config_filename, "w") as f:
            json.dump(ao_config_contents, f)

    return output_path
if __name__ == "__main__":
    # NOTE: this must be run from within Blender and by default saves files in "blender_armatures/" relative to the directory containing the script

    export_path = None
    try:
        # the joined path was previously computed and discarded, so the documented default was never used
        export_path = os.path.join(
            os.path.dirname(bpy.context.space_data.text.filepath), "blender_armatures"
        )
    except AttributeError:
        # no text editor context, 'export_path' stays None and must come from '--export-path'
        print(
            "Couldn't get the directory from the filepath. E.g. running from commandline."
        )

    # -----------------------------------------------------------
    # To use this script in Blender editor:
    # 1. run with export meshes True
    # 2. undo changes in the editor
    # 3. run with export meshes False
    # 4. undo changes in the editor

    # NOTE: the following settings are overridden by commandline arguments if provided

    # Optionally override the save directory with an absolute path of your choice
    # export_path = "/home/my_path_choice/"

    export_urdf = False
    export_meshes = False
    export_ao_config = False
    round_collision_scales = False
    fix_collision_scales = False
    fix_materials = False

    # visual shape export flags for debugging
    link_visuals = True
    collision_visuals = False
    joint_visuals = False
    receptacle_visuals = False
    # -----------------------------------------------------------

    # -----------------------------------------------------------
    # To use from the commandline:
    # 1. `blender <file>.blend --background --python blender_armature_to_urdf.py -- --export-path <export directory>`
    # 2. add `--export-meshes` to export the link .glbs
    # Note: ' -- ' tells Blender to ignore the remaining arguments, so we pass anything after that into the script arguments below:
    import sys

    argv = sys.argv
    # argparse expects a list of strings; an empty list means "no script arguments"
    py_argv = []
    if "--" in argv:
        py_argv = argv[argv.index("--") + 1 :]  # get all args after "--"

    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--export-path",
        default=export_path,
        type=str,
        help="Path to the output directory for meshes and URDF.",
    )
    parser.add_argument(
        "--export-meshes",
        action="store_true",
        default=export_meshes,
        help="Export meshes for the link objects. If not set, instead generate the URDF.",
    )
    parser.add_argument(
        "--export-ao-config",
        action="store_true",
        default=export_ao_config,
        help="Export a *.ao_config.json file for the URDF.",
    )
    parser.add_argument(
        "--export-urdf",
        action="store_true",
        default=export_urdf,
        help="Export the *.urdf file.",
    )
    # Debugging flags:
    parser.add_argument(
        "--no-link-visuals",
        action="store_true",
        default=not link_visuals,
        help="Don't include visual mesh shapes in the exported URDF. E.g. for debugging.",
    )
    parser.add_argument(
        "--collision-visuals",
        action="store_true",
        default=collision_visuals,
        help="Include visual shapes for collision primitives in the exported URDF. E.g. for debugging.",
    )
    parser.add_argument(
        "--joint-visuals",
        action="store_true",
        default=joint_visuals,
        help="Include visual box shapes for joint pivot locations in the exported URDF. E.g. for debugging.",
    )
    parser.add_argument(
        "--receptacle-visuals",
        action="store_true",
        default=receptacle_visuals,
        help="Include visual mesh shapes for receptacles in the exported URDF. E.g. for debugging.",
    )
    parser.add_argument(
        "--round-collision-scales",
        action="store_true",
        default=round_collision_scales,
        help="Round all scale elements for collision shapes to 4 decimal points (millimeter accuracy).",
    )
    parser.add_argument(
        "--fix-collision-scales",
        action="store_true",
        default=fix_collision_scales,
        help="Flip all negative scale elements for collision shapes.",
    )
    parser.add_argument(
        "--fix-materials",
        action="store_true",
        default=fix_materials,
        help="Fixes materials with ior==1.0 which cause glTF export failure.",
    )

    args = parser.parse_args(py_argv)
    export_urdf = args.export_urdf
    export_meshes = args.export_meshes
    export_ao_config = args.export_ao_config
    export_path = args.export_path

    print(
        f"export_urdf : {export_urdf} | export_meshes : {export_meshes} | export_ao_config : {export_ao_config} | export_path : {export_path}"
    )

    # -----------------------------------------------------------

    assert (
        export_path is not None
    ), "No export path provided. If running from commandline, provide a path with '--export-path ' after ' -- '."

    output_path = export(
        export_path,
        {
            "armature": get_armature(),
            "round_collision_scales": args.round_collision_scales,
            "fix_collision_scales": args.fix_collision_scales,
            #'def_limit_effort': 100, #custom default effort limit for joints
            #'def_limit_vel': 3, #custom default vel limit for joints
        },
        export_urdf=export_urdf,
        export_meshes=export_meshes,
        export_ao_config=export_ao_config,
        fix_materials=args.fix_materials,
        link_visuals=not args.no_link_visuals,
        collision_visuals=args.collision_visuals,
        joint_visuals=args.joint_visuals,
        receptacle_visuals=args.receptacle_visuals,
    )
    print(f"\n ======== Output saved to {output_path} ========\n")
def to_str_csv(data: Any) -> str:
    """
    Format some data element as a string for csv such that it fits nicely into a cell.

    :param data: None, a str, an int, a float, or a list/tuple of elements.

    :return: "None" for None, the string itself for str, str(data) for numbers, and every element followed by " |" for lists and tuples.
    :raises NotImplementedError: For any other data type.
    """
    if data is None:
        return "None"
    if isinstance(data, str):
        # NOTE: returned verbatim, commas inside the string are not escaped
        return data
    if isinstance(data, (int, float)):
        return str(data)
    if isinstance(data, (list, tuple)):
        # single join instead of repeated string concatenation
        return "".join(f"{elem} |" for elem in data)

    raise NotImplementedError(f"Data type {type(data)} is not supported in csv string.")
def get_labels_from_dict(results_dict: Dict[str, Dict[str, Any]]) -> List[str]:
    """
    Get a list of column labels for the csv by scraping dict keys from the inner dict layers.

    :param results_dict: Mapping from a row identifier (the ao handle) to that row's ``{label: value}`` dict.
    :return: Every distinct inner key, ordered by first appearance across the rows.
    """
    # A dict doubles as an insertion-ordered set: membership is O(1) per key
    # instead of a linear scan over all labels collected so far.
    seen_labels: Dict[str, None] = {}
    for ao_dict in results_dict.values():
        for dict_key in ao_dict:
            seen_labels.setdefault(dict_key, None)
    return list(seen_labels)
def recompute_ao_bbs(ao: ManagedArticulatedObject) -> None:
    """
    Refresh the cumulative bounding box of every link SceneNode of an ao.
    NOTE: Works around an observed loading bug. Call before trying to peek an AO.

    :param ao: The articulated object whose link nodes are refreshed.
    """
    # indices run from -1 through num_links - 1
    # NOTE(review): -1 presumably addresses the base link - confirm against the habitat-sim API
    link_indices = range(-1, ao.num_links)
    for link_node in map(ao.get_link_scene_node, link_indices):
        link_node.compute_cumulative_bb()
{ao_handle}. '{repr(e)}'") + asset_failure_message = repr(e) + + if ao is None: + # load failed, record the message and continue + ao_test_results[ao_short_handle]["failure_log"] = to_str_csv( + asset_failure_message + ) + continue + + # check joint popping + unstable_aos, joint_errors = check_joint_popping( + sim, out_dir=args.out_dir if args.save_images else None, dbv=dbv + ) + + if len(unstable_aos) > 0: + ao_test_results[ao_short_handle][ + "joint_popping_error" + ] = joint_errors[0] + + ########################################### + # produce a gif of actuation + # TODO: + + ########################################### + # load the receptacles + try: + recs = find_receptacles(sim) + except Exception as e: + print(f"Failed to load receptacles for {ao_handle}. '{repr(e)}'") + asset_failure_message = repr(e) + ao_test_results[ao_short_handle]["failure_log"] = to_str_csv( + asset_failure_message + ) + + ########################################### + # snap an image and sort into category subfolder + recompute_ao_bbs(ao) + hash_name = object_hash_from_handle(ao_handle) + cat = mi.get_object_category(hash_name) + if cat is None: + cat = "None" + + ao_peek = dbv.peek(ao.handle, peek_all_axis=True) + cat_dir = os.path.join(args.out_dir, f"ao_categories/{cat}/") + os.makedirs(cat_dir, exist_ok=True) + ao_peek.save(cat_dir, prefix=hash_name + "__") + + ############################################# + # DONE: clear the scene for next iteration + aom.remove_all_objects() + + # check if done with last ao + if ao_ix >= len(ao_handles): + break + + csv_filepath = os.path.join(args.out_dir, "siro_ao_test_results.csv") + export_results_csv(csv_filepath, ao_test_results) diff --git a/tools/check_siro_scenes.py b/tools/check_siro_scenes.py new file mode 100644 index 0000000000..050cc30a92 --- /dev/null +++ b/tools/check_siro_scenes.py @@ -0,0 +1,765 @@ +import json +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple + +import 
def to_str_csv(data: Any) -> str:
    """
    Format some data element as a string for csv such that it fits nicely into a cell.

    :param data: ``None``, a string, an int/float, or a list of elements to format.
    :return: "None" for ``None``, the string unchanged, ``str(number)`` for numbers, or every list element followed by ";".
    :raises NotImplementedError: If the data type is not one of the supported types.
    """
    # a missing value becomes an explicit marker instead of aborting the whole export
    if data is None:
        return "None"
    if isinstance(data, str):
        return data
    if isinstance(data, (int, float)):
        return str(data)
    if isinstance(data, list):
        # ";" terminates each element so the list stays inside one comma-delimited cell
        return "".join(f"{elem};" for elem in data)

    raise NotImplementedError(f"Data type {type(data)} is not supported in csv string.")
def check_joint_popping(
    sim: Simulator,
    out_dir: Optional[str] = None,
    dbv: Optional[DebugVisualizer] = None,
    sim_seconds: float = 2.0,
    eps: float = 1e-3,
) -> Tuple[List[str], List[float]]:
    """
    Find the articulated objects which are not stable during simulation.
    Records the initial joint state, steps physics for ``sim_seconds``, then reads the joint state again. Changes indicate popping, collisions, loose hinges, or other instability.

    :param sim: The Simulator instance holding the articulated objects.
    :param out_dir: If provided, save debug images of each unstable object to the output directory prefixed "joint_pop__<ao_handle>__".
    :param dbv: Optional DebugVisualizer used for the images. One is created when images are requested and none is given.
    :param sim_seconds: How long to step physics between the two joint state readings.
    :param eps: The cumulative error must be above this threshold to count as "unstable".
    :return: A tuple of two parallel lists: the handles of the unstable objects and, for each, the sum of absolute joint position changes across all of its joints.
    """

    if out_dir is not None and dbv is None:
        dbv = DebugVisualizer(sim)

    aom = sim.get_articulated_object_manager()

    # snapshot the joint positions before stepping
    ao_initial_joint_states = {
        ao_handle: ao.joint_positions
        for ao_handle, ao in aom.get_objects_by_handle_substring().items()
    }

    sim.step_physics(sim_seconds)

    # record the ao handles
    unstable_aos: List[str] = []
    # record the sum of errors across all joints
    cumulative_errors: List[float] = []

    for ao_handle, ao in aom.get_objects_by_handle_substring().items():
        initial_jp = ao_initial_joint_states[ao_handle]
        jp = ao.joint_positions
        if initial_jp != jp:
            cumulative_error = sum(
                abs(before - after) for before, after in zip(initial_jp, jp)
            )
            if cumulative_error > eps:
                cumulative_errors.append(cumulative_error)
                unstable_aos.append(ao_handle)
                if out_dir is not None:
                    dbv.peek(ao_handle, peek_all_axis=True).save(
                        output_path=out_dir, prefix=f"joint_pop__{ao_handle}__"
                    )

    return unstable_aos, cumulative_errors
def save_region_counts_csv(region_counts: Dict[str, int], filepath: str) -> None:
    """
    Save the region counts to a csv file.

    Writes a "region_name, count" header followed by one row per region.

    :param region_counts: Mapping from region category name to the number of occurrences.
    :param filepath: Destination path. Must end in ".csv".
    """

    assert filepath.endswith(".csv")

    with open(filepath, "w") as f:
        f.write("region_name, count\n")
        # rows carry exactly the two header columns; a trailing separator
        # would add a third, empty column the header does not declare
        f.writelines(
            f"{region_name}, {count}\n" for region_name, count in region_counts.items()
        )

    print(f"Wrote region counts csv to {filepath}")
+ """ + + assert len(clutter_object_handles) > 0 + + print(f"Checking Receptacle accessibility for {rec.unique_name}") + + # first check if the receptacle is close enough to the navmesh + rec_global_keypoints = sutils.get_global_keypoints_from_bb( + rec.bounds, rec.get_global_transform(sim) + ) + floor_point = None + for keypoint in rec_global_keypoints: + floor_point = sim.pathfinder.snap_point(keypoint, island_index=island_index) + if not np.isnan(floor_point[0]): + break + if np.isnan(floor_point[0]): + print(" - Receptacle too far from active navmesh boundary.") + return False, "access_filtered" + + # then check that the height is acceptable + rec_min = min(rec_global_keypoints, key=lambda x: x[1]) + if rec_min[1] - floor_point[1] > max_height: + print( + f" - Receptacle exceeds maximum height {rec_min[1]-floor_point[1]} vs {max_height}." + ) + return False, "height_filtered" + + # try to sample 10 objects on the receptacle + target_number = 10 + obj_samp = ObjectSampler( + clutter_object_handles, + ["rec set"], + orientation_sample="up", + num_objects=(1, target_number), + ) + obj_samp.max_sample_attempts = len(clutter_object_handles) + obj_samp.max_placement_attempts = 10 + obj_samp.target_objects_number = target_number + rec_set_unique_names = [rec.unique_name] + rec_set_obj = hab_receptacle.ReceptacleSet( + "rec set", [""], [], rec_set_unique_names, [] + ) + recep_tracker = hab_receptacle.ReceptacleTracker( + {}, + {"rec set": rec_set_obj}, + ) + + new_objs = [] + try: + new_objs = obj_samp.sample(sim, recep_tracker, [], snap_down=True) + except Exception as e: + print(f" - generation failed with internal exception {repr(e)}") + + # if we can't sample objects, this receptacle is out + if len(new_objs) == 0: + print(" - failed to sample any objects.") + return False, "access_filtered" + print(f" - sampled {len(new_objs)} / {target_number} objects.") + + # now try unoccluded navmesh snapping to the objects to test accessibility + obj_positions = 
def init_rec_filter_data_dict() -> Dict[str, Any]:
    """
    Build a fresh, empty rec_filter_data dictionary.

    :return: A new dict holding an empty name list per filter category plus the default threshold and height entries.
    """
    rec_filter_data: Dict[str, Any] = {}
    rec_filter_data["active"] = []
    rec_filter_data["manually_filtered"] = []
    rec_filter_data["access_filtered"] = []
    # set in filter procedure
    rec_filter_data["access_threshold"] = -1
    rec_filter_data["stability_filtered"] = []
    # set in filter procedure
    rec_filter_data["stability threshold"] = -1
    rec_filter_data["height_filtered"] = []
    rec_filter_data["max_height"] = 1.2
    rec_filter_data["min_height"] = 0
    return rec_filter_data
def set_filter_status_for_rec(
    rec: hab_receptacle.Receptacle,
    filter_status: str,
    rec_filter_data: Dict[str, Any],
    ignore_existing_status: Optional[List[str]] = None,
) -> None:
    """
    Move a Receptacle to the given filter status within the filter dictionary.

    :param rec: The Receptacle instance.
    :param filter_status: The status to assign.
    :param rec_filter_data: The current filter dictionary, modified in place.
    :param ignore_existing_status: An optional list of filter types to lock, preventing re-assignment.
    """

    locked_statuses = [] if ignore_existing_status is None else ignore_existing_status
    known_statuses = (
        "access_filtered",
        "stability_filtered",
        "height_filtered",
        "manually_filtered",
        "active",
    )
    assert filter_status in known_statuses
    rec_name = rec.unique_name
    for existing_status in known_statuses:
        status_members = rec_filter_data[existing_status]
        if rec_name not in status_members:
            continue
        if existing_status in locked_statuses:
            # a locked status wins: leave it in place and skip the new assignment
            print(
                f"Trying to assign filter status {filter_status} but existing status {existing_status} in ignore list. Aborting assignment."
            )
            return
        # drop the stale entry before registering the new status
        status_members.remove(rec_name)
    rec_filter_data[filter_status].append(rec_name)
+ """ + clutter_object_set = [ + "002_master_chef_can", + "003_cracker_box", + "004_sugar_box", + "005_tomato_soup_can", + "007_tuna_fish_can", + "008_pudding_box", + "009_gelatin_box", + "010_potted_meat_can", + "024_bowl", + ] + clutter_object_handles = [] + for obj_name in clutter_object_set: + matching_handles = ( + sim.metadata_mediator.object_template_manager.get_template_handles(obj_name) + ) + assert ( + len(matching_handles) > 0 + ), f"No matching template for '{obj_name}' in the dataset." + clutter_object_handles.append(matching_handles[0]) + return clutter_object_handles + + +def run_rec_filter_analysis( + sim, out_dir: str, open_default_links: bool = True, keep_manual_filters: bool = True +) -> None: + """ + Collect all receptacles for the scene and run an accessibility check, saving the resulting filter file. + + :param out_dir: Where to write the filter files. + :param open_default_links: Whether or not to open default links when considering final accessible Receptacles set. + :param keep_manual_filters: Whether to keep or override existing manual filter definitions. 
+ """ + + rec_filter_dict = init_rec_filter_data_dict() + + # load the clutter objects + sim.metadata_mediator.object_template_manager.load_configs( + "data/objects/ycb/configs/" + ) + clutter_object_handles = initialize_clutter_object_set(sim) + + # recompute the navmesh with expect parameters + navmesh_config_and_recompute(sim) + + # get the largest indoor island + largest_island = get_largest_island_index(sim.pathfinder, sim, allow_outdoor=False) + + # keep manually filtered receptacles + ignore_existing_status = [] + if keep_manual_filters: + existing_scene_filter_file = hab_receptacle.get_scene_rec_filter_filepath( + sim.metadata_mediator, sim.curr_scene_name + ) + if existing_scene_filter_file is not None: + filter_strings = hab_receptacle.get_excluded_recs_from_filter_file( + existing_scene_filter_file, filter_types=["manually_filtered"] + ) + rec_filter_dict["manually_filtered"] = filter_strings + ignore_existing_status.append("manually_filtered") + + recs = hab_receptacle.find_receptacles( + sim, exclude_filter_strings=rec_filter_dict["manually_filtered"] + ) + + # compute a map from parent object to Receptacles + parent_handle_to_rec: Dict[str, List[hab_receptacle.Receptacle]] = defaultdict( + lambda: [] + ) + for rec in recs: + parent_handle_to_rec[rec.parent_object_handle].append(rec) + + # compute the default accessibility with all closed links + default_active_set: List[hab_receptacle.Receptacle] = [] + for rix, rec in enumerate(recs): + rec_accessible, filter_type = check_rec_accessibility( + sim, rec, clutter_object_handles, island_index=largest_island + ) + if rec_accessible: + default_active_set.append(rec) + set_filter_status_for_rec( + rec, + filter_type, + rec_filter_dict, + ignore_existing_status=ignore_existing_status, + ) + print(f"-- progress = {rix}/{len(recs)} --") + + # open default links and re-compute accessibility for each AO + # the difference between default state accessibility and open state accessibility defines the "within_set" + 
def get_global_faucet_points(sim: habitat_sim.Simulator) -> Dict[str, List[mn.Vector3]]:
    """
    Collect the faucet marker points of every object in the scene in global space.

    :param sim: The Simulator instance whose objects are scanned.
    :return: A dict mapping the handle of each object with a "faucets" taskset to its list of global points.
    """
    handle_to_points = {}
    for obj in sutils.get_all_objects(sim):
        marker_sets = obj.marker_sets
        if not marker_sets.has_taskset("faucets"):
            # no faucet annotations on this object
            continue
        world_points = []
        handle_to_points[obj.handle] = world_points
        faucet_points_by_link = marker_sets.get_taskset_points("faucets")
        for link_name, link_marker_subsets in faucet_points_by_link.items():
            # markers registered under "root" use link id -1, others are looked up by name
            link_id = -1 if link_name == "root" else obj.get_link_id_from_name(link_name)
            for local_points in link_marker_subsets.values():
                world_points.extend(
                    obj.transform_local_pts_to_world(local_points, link_id)
                )
    return handle_to_points
+ + for target_action in args.actions: + assert ( + target_action in available_check_actions + ), f"provided action {target_action} is not in the valid set: {available_check_actions}" + + target_check_actions = args.actions + + os.makedirs(args.out_dir, exist_ok=True) + + # create an initial simulator config + sim_settings: Dict[str, Any] = default_sim_settings + sim_settings["scene_dataset_config_file"] = args.dataset + cfg = make_cfg(sim_settings) + + # pre-initialize a MetadataMediator to iterate over scenes + mm = MetadataMediator() + mm.active_dataset = args.dataset + cfg.metadata_mediator = mm + + mi = MetadataInterface(default_metadata_dict) + + # keyed by scene handle + scene_test_results: Dict[str, Dict[str, Any]] = {} + + # count all region category names in all scenes + region_counts: Dict[str, int] = defaultdict(lambda: 0) + + num_scenes = len(mm.get_scene_handles()) + + # for each scene, initialize a fresh simulator and run tests + for s_ix, scene_handle in enumerate(mm.get_scene_handles()): + print("=================================================================") + print( + f"Setting up scene for {scene_handle} ({s_ix}|{num_scenes} = {s_ix/float(num_scenes)*100}%)" + ) + cfg.sim_cfg.scene_id = scene_handle + print(" - init") + with Simulator(cfg) as sim: + dbv = DebugVisualizer(sim) + + mi.refresh_scene_caches(sim) + + scene_test_results[sim.curr_scene_name] = {} + scene_test_results[sim.curr_scene_name][ + "ros" + ] = sim.get_rigid_object_manager().get_num_objects() + scene_test_results[sim.curr_scene_name][ + "aos" + ] = sim.get_articulated_object_manager().get_num_objects() + + scene_out_dir = os.path.join(args.out_dir, f"{sim.curr_scene_name}/") + + ########################################## + # get images of all the global faucet points in the scene + if "faucet_points" in target_check_actions: + scene_obj_global_faucet_points = get_global_faucet_points(sim) + for obj_handle, global_points in scene_obj_global_faucet_points.items(): + circles = 
[ + (point, 0.1, mn.Vector3(0, 1, 0), mn.Color4.red()) + for point in global_points + ] + dbv.peek( + obj_handle, peek_all_axis=True, debug_circles=circles + ).save(scene_out_dir, f"faucets_{obj_handle}_") + + ########################################## + # gather all Receptacle.unique_name in the scene + if "rec_unique_names" in target_check_actions: + all_recs = hab_receptacle.find_receptacles(sim) + unique_names = [rec.unique_name for rec in all_recs] + scene_test_results[sim.curr_scene_name][ + "rec_unique_names" + ] = unique_names + + ########################################## + # receptacle filter computation + if "rec_filters" in target_check_actions: + run_rec_filter_analysis( + sim, args.out_dir, open_default_links=True, keep_manual_filters=True + ) + + ########################################## + # Check region counts + if "region_counts" in target_check_actions: + print(" - region counts") + scene_region_counts = get_region_counts(sim) + for region_name, count in scene_region_counts.items(): + region_counts[region_name] += count + + ########################################## + # Check for joint popping + if "joint_popping" in target_check_actions: + print(" - check joint popping") + unstable_aos, joint_errors = check_joint_popping( + sim, out_dir=scene_out_dir if args.save_images else None, dbv=dbv + ) + if len(unstable_aos) > 0: + scene_test_results[sim.curr_scene_name]["unstable_aos"] = "" + for ix, ao_handle in enumerate(unstable_aos): + scene_test_results[sim.curr_scene_name][ + "unstable_aos" + ] += f"{ao_handle}({joint_errors[ix]}) | " + + ############################################ + # analyze and visualize regions + if "visualize_regions" in target_check_actions: + print(" - check and visualize regions") + if args.save_images: + save_region_visualizations( + sim, os.path.join(scene_out_dir, "regions/"), dbv + ) + expected_regions = ["kitchen", "living room", "bedroom"] + all_region_cats = [ + region.category.name() for region in 
sim.semantic_scene.regions + ] + missing_expected_regions = [ + expected_region + for expected_region in expected_regions + if expected_region not in all_region_cats + ] + if len(missing_expected_regions) > 0: + scene_test_results[sim.curr_scene_name][ + "missing_expected_regions" + ] = "" + for expected_region in missing_expected_regions: + scene_test_results[sim.curr_scene_name][ + "missing_expected_regions" + ] += f"{expected_region} | " + + ############################################## + # analyze semantics + if "analyze_semantics" in target_check_actions: + print(" - check and visualize semantics") + scene_test_results[sim.curr_scene_name][ + "objects_missing_semantic_class" + ] = [] + missing_semantics_output = os.path.join( + scene_out_dir, "missing_semantics/" + ) + for obj in sutils.get_all_objects(sim): + if mi.get_object_instance_category(obj) is None: + scene_test_results[sim.curr_scene_name][ + "objects_missing_semantic_class" + ].append(obj.handle) + if args.save_images: + os.makedirs(missing_semantics_output, exist_ok=True) + dbv.peek(obj, peek_all_axis=True).save( + missing_semantics_output, f"{obj.handle}__" + ) + + csv_filepath = os.path.join(args.out_dir, "siro_scene_test_results.csv") + export_results_csv(csv_filepath, scene_test_results) + if "region_counts" in target_check_actions: + region_count_csv_filepath = os.path.join(args.out_dir, "region_counts.csv") + save_region_counts_csv(region_counts, region_count_csv_filepath) diff --git a/tools/collision_shape_automation.py b/tools/collision_shape_automation.py new file mode 100644 index 0000000000..d279649fbe --- /dev/null +++ b/tools/collision_shape_automation.py @@ -0,0 +1,2633 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
import argparse
import csv
import ctypes
import math
import os
import random
import sys
import time
from typing import Any, Dict, List, Optional, Tuple

coacd_imported = False
try:
    import coacd
    import trimesh

    coacd_imported = True
except Exception:
    coacd_imported = False
    print("Failed to import coacd, is it installed? Linux only: 'pip install coacd'")

# not adding this causes some failures in mesh import
flags = sys.getdlopenflags()
sys.setdlopenflags(flags | ctypes.RTLD_GLOBAL)


# imports from Habitat-lab
# NOTE: requires PR 1108 branch: rearrange-gen-improvements (https://github.com/facebookresearch/habitat-lab/pull/1108)
import habitat.datasets.rearrange.samplers.receptacle as hab_receptacle
import habitat.sims.habitat_simulator.debug_visualizer as hab_debug_vis
import magnum as mn
import numpy as np
from habitat.sims.habitat_simulator.sim_utilities import snap_down

import habitat_sim
from habitat_sim.utils.settings import default_sim_settings, make_cfg

# object samples:
# chair - good approximation: 0a5e809804911e71de6a4ef89f2c8fef5b9291b3.glb
# shelves - bad approximation: d1d1e0cdaba797ee70882e63f66055675c3f1e7f.glb

# 71 equidistant points on a unit hemisphere generated from icosphere subdivision
# Sphere center is (0,0,0) and no points lie on x,z plane
# used for hemisphere raycasting from Receptacle points
icosphere_points_subdiv_3 = [
    mn.Vector3(-0.276388, 0.447220, -0.850649),
    mn.Vector3(-0.483971, 0.502302, -0.716565),
    mn.Vector3(-0.232822, 0.657519, -0.716563),
    mn.Vector3(0.723607, 0.447220, -0.525725),
    mn.Vector3(0.531941, 0.502302, -0.681712),
    mn.Vector3(0.609547, 0.657519, -0.442856),
    mn.Vector3(0.723607, 0.447220, 0.525725),
    mn.Vector3(0.812729, 0.502301, 0.295238),
    mn.Vector3(0.609547, 0.657519, 0.442856),
    mn.Vector3(-0.276388, 0.447220, 0.850649),
    mn.Vector3(-0.029639, 0.502302, 0.864184),
    mn.Vector3(-0.232822, 0.657519, 0.716563),
    mn.Vector3(-0.894426, 0.447216, 0.000000),
    mn.Vector3(-0.831051, 0.502299, 0.238853),
    mn.Vector3(-0.753442, 0.657515, 0.000000),
    mn.Vector3(-0.251147, 0.967949, 0.000000),
    mn.Vector3(-0.077607, 0.967950, 0.238853),
    mn.Vector3(0.000000, 1.000000, 0.000000),
    mn.Vector3(-0.525730, 0.850652, 0.000000),
    mn.Vector3(-0.361800, 0.894429, 0.262863),
    mn.Vector3(-0.638194, 0.723610, 0.262864),
    mn.Vector3(-0.162456, 0.850654, 0.499995),
    mn.Vector3(-0.447209, 0.723612, 0.525728),
    mn.Vector3(-0.688189, 0.525736, 0.499997),
    mn.Vector3(-0.483971, 0.502302, 0.716565),
    mn.Vector3(0.203181, 0.967950, 0.147618),
    mn.Vector3(0.138197, 0.894430, 0.425319),
    mn.Vector3(0.052790, 0.723612, 0.688185),
    mn.Vector3(0.425323, 0.850654, 0.309011),
    mn.Vector3(0.361804, 0.723612, 0.587778),
    mn.Vector3(0.262869, 0.525738, 0.809012),
    mn.Vector3(0.531941, 0.502302, 0.681712),
    mn.Vector3(0.203181, 0.967950, -0.147618),
    mn.Vector3(0.447210, 0.894429, 0.000000),
    mn.Vector3(0.670817, 0.723611, 0.162457),
    mn.Vector3(0.425323, 0.850654, -0.309011),
    mn.Vector3(0.670817, 0.723611, -0.162457),
    mn.Vector3(0.850648, 0.525736, 0.000000),
    mn.Vector3(0.812729, 0.502301, -0.295238),
    mn.Vector3(-0.077607, 0.967950, -0.238853),
    mn.Vector3(0.138197, 0.894430, -0.425319),
    mn.Vector3(0.361804, 0.723612, -0.587778),
    mn.Vector3(-0.162456, 0.850654, -0.499995),
    mn.Vector3(0.052790, 0.723612, -0.688185),
    mn.Vector3(0.262869, 0.525738, -0.809012),
    mn.Vector3(-0.029639, 0.502302, -0.864184),
    mn.Vector3(-0.361800, 0.894429, -0.262863),
    mn.Vector3(-0.447209, 0.723612, -0.525728),
    mn.Vector3(-0.638194, 0.723610, -0.262864),
    mn.Vector3(-0.688189, 0.525736, -0.499997),
    mn.Vector3(-0.831051, 0.502299, -0.238853),
    mn.Vector3(-0.956626, 0.251149, 0.147618),
    mn.Vector3(-0.861804, 0.276396, 0.425322),
    mn.Vector3(-0.670821, 0.276397, 0.688189),
    mn.Vector3(-0.436007, 0.251152, 0.864188),
    mn.Vector3(-0.155215, 0.251152, 0.955422),
    mn.Vector3(0.138199, 0.276397, 0.951055),
    mn.Vector3(0.447215, 0.276397, 0.850649),
    mn.Vector3(0.687159, 0.251152, 0.681715),
    mn.Vector3(0.860698, 0.251151, 0.442858),
    mn.Vector3(0.947213, 0.276396, 0.162458),
    mn.Vector3(0.947213, 0.276397, -0.162458),
    mn.Vector3(0.860698, 0.251151, -0.442858),
    mn.Vector3(0.687159, 0.251152, -0.681715),
    mn.Vector3(0.447216, 0.276397, -0.850648),
    mn.Vector3(0.138199, 0.276397, -0.951055),
    mn.Vector3(-0.155215, 0.251152, -0.955422),
    mn.Vector3(-0.436007, 0.251152, -0.864188),
    mn.Vector3(-0.670820, 0.276396, -0.688190),
    mn.Vector3(-0.861804, 0.276394, -0.425323),
    mn.Vector3(-0.956626, 0.251149, -0.147618),
]


def get_scaled_hemisphere_vectors(scale: float) -> List[mn.Vector3]:
    """
    Scales the icosphere_points for use with raycasting applications.

    :param scale: Scalar multiplier applied to every unit hemisphere vector.
    :return: A new list with one scaled vector per entry of `icosphere_points_subdiv_3`.
    """
    return [v * scale for v in icosphere_points_subdiv_3]


class COACDParams:
    """
    Plain container for the CoACD convex decomposition parameters, initialized to this tool's defaults.
    """

    def __init__(
        self,
    ) -> None:
        # Parameter tuning tricks from https://github.com/SarahWeiii/CoACD:

        # The default parameters are fast versions. If you care less about running time but more about the number of components, try to increase searching depth, searching node, and searching iteration for better cutting strategies.
        self.threshold = 0.05  # adjust the threshold (0.01~1) to balance the level of detail and the number of decomposed components. A higher value gives coarser results, and a lower value gives finer-grained results. You can refer to Fig. 14 in our paper for more details.
        self.max_convex_hull = -1
        self.preprocess = True  # ensure input mesh is 2-manifold solid if you want to skip pre-process. Skipping manifold pre-processing can better preserve input details, but can crash or fail otherwise if input is not manifold.
        self.preprocess_resolution = 30  # controls the quality of manifold preprocessing. A larger value can make the preprocessed mesh closer to the original mesh but also lead to more triangles and longer runtime.

        self.mcts_nodes = 20
        self.mcts_iterations = 150
        self.mcts_max_depth = 3
        self.pca = False
        self.merge = True
        self.seed = 0

    def __str__(self) -> str:
        return f"COACDParams(threshold={self.threshold} | max_convex_hull={self.max_convex_hull} | preprocess={self.preprocess} | preprocess_resolution={self.preprocess_resolution} | mcts_nodes={self.mcts_nodes} | mcts_iterations={self.mcts_iterations} | mcts_max_depth={self.mcts_max_depth} | pca={self.pca} | merge={self.merge} | seed={self.seed})"


def print_dict_structure(input_dict: Dict[Any, Any], whitespace: str = "") -> None:
    """
    Quick structure investigation for dictionary.
    Prints dict key->type recursively with incremental whitespace formatting.

    :param input_dict: The dictionary to describe.
    :param whitespace: Prefix for the current nesting level. The header and footer are printed only when this is empty (the top-level call).
    """
    if whitespace == "":
        print("-----------------------------------")
        print("Print Dict Structure Results:")
    for key, value in input_dict.items():
        if isinstance(value, dict):
            print(whitespace + f"{key}:-")
            print_dict_structure(input_dict=value, whitespace=whitespace + " ")
        else:
            print(whitespace + f"{key}: {type(value)}")
    if whitespace == "":
        print("-----------------------------------")


# =======================================================================
# Range3D surface sampling utils


def compute_area_weights_for_range3d_faces(range3d: mn.Range3D) -> List[float]:
    """
    Compute a set of area weights from a Range3D.

    :return: Six cumulative face-area fractions (the last entry is total/total), ordered front, top, left, back, bottom, right to match `get_range3d_sample_planes`.
    """

    face_areas = [
        range3d.size_x() * range3d.size_y(),  # front/back
        range3d.size_x() * range3d.size_z(),  # top/bottom
        range3d.size_y() * range3d.size_z(),  # sides
    ]
    area_accumulator = []
    for ix in range(6):
        # each of the 3 unique areas is shared by two opposing faces
        area_ix = ix % 3
        if ix == 0:
            area_accumulator.append(face_areas[area_ix])
        else:
            area_accumulator.append(face_areas[area_ix] + area_accumulator[-1])

    normalized_area_accumulator = [x / area_accumulator[-1] for x in area_accumulator]

    return normalized_area_accumulator


def get_range3d_sample_planes(
    range3d: mn.Range3D,
) -> List[Tuple[mn.Vector3, mn.Vector3, mn.Vector3]]:
    """
    Get origin and basis vectors for each face's sample planes.

    :return: One (origin, edge0, edge1) tuple per face in the order front, top, left, back, bottom, right.
    """
    # For each face a starting point and two edge vectors (un-normalized)
    face_info: List[Tuple[mn.Vector3, mn.Vector3, mn.Vector3]] = [
        (
            range3d.front_bottom_left,
            mn.Vector3.x_axis(range3d.size_x()),
            mn.Vector3.y_axis(range3d.size_y()),
        ),  # front
        (
            range3d.back_top_left,
            mn.Vector3.x_axis(range3d.size_x()),
            mn.Vector3.z_axis(range3d.size_z()),
        ),  # top
        (
            range3d.back_bottom_left,
            mn.Vector3.y_axis(range3d.size_y()),
            mn.Vector3.z_axis(range3d.size_z()),
        ),  # left
        (
            range3d.back_bottom_left,
            mn.Vector3.x_axis(range3d.size_x()),
            mn.Vector3.y_axis(range3d.size_y()),
        ),  # back
        (
            range3d.back_bottom_left,
            mn.Vector3.x_axis(range3d.size_x()),
            mn.Vector3.z_axis(range3d.size_z()),
        ),  # bottom
        (
            range3d.back_bottom_right,
            mn.Vector3.y_axis(range3d.size_y()),
            mn.Vector3.z_axis(range3d.size_z()),
        ),  # right
    ]
    return face_info


def sample_jittered_points_from_range3d(range3d: mn.Range3D, num_points: int = 100):
    """
    Use jittered sampling to compute a more uniformly distributed set of random points.

    Each face is divided into a grid of cells and one random point is placed in every cell.

    :param range3d: The box whose surface is sampled.
    :param num_points: Requested number of points. Per-face counts are rounded up, so slightly more may be returned.
    :return: A list of 6 lists of points, one per face.
    """
    normalized_area_accumulator = compute_area_weights_for_range3d_faces(range3d)
    # recover per-face area fractions from the cumulative weights
    normalized_areas = []
    for vix in range(len(normalized_area_accumulator)):
        if vix == 0:
            normalized_areas.append(normalized_area_accumulator[vix])
        else:
            normalized_areas.append(
                normalized_area_accumulator[vix] - normalized_area_accumulator[vix - 1]
            )

    # get number of points per face based on area
    # NOTE: rounded up, so may be slightly more points than requested.
    points_per_face = [max(1, math.ceil(x * num_points)) for x in normalized_areas]

    # get face plane basis
    face_info = get_range3d_sample_planes(range3d)

    # one internal list of each face of the box:
    samples = []
    for _ in range(6):
        samples.append([])

    # for each face, jittered sample of total area:
    for face_ix, f in enumerate(face_info):
        # get ratio of width/height in local space to plan jittering
        aspect_ratio = f[1].length() / f[2].length()
        num_wide = max(1, int(math.sqrt(aspect_ratio * points_per_face[face_ix])))
        num_high = max(1, int((points_per_face[face_ix] + num_wide - 1) / num_wide))

        # get jittered cell sizes
        dx = f[1] / num_wide
        dy = f[2] / num_high
        for x in range(num_wide):
            for y in range(num_high):
                # get cell origin
                org = f[0] + x * dx + y * dy
                # point is randomly placed in the cell
                point = org + random.random() * dx + random.random() * dy
                samples[face_ix].append(point)

    return samples


def sample_points_from_range3d(
    range3d: mn.Range3D, num_points: int = 100
) -> List[List[mn.Vector3]]:
    """
    Sample 'num_points' from the surface of a box defined by 'range3d'.

    :return: A list of 6 lists of points, one per face. Faces are chosen with probability proportional to their area.
    """

    # -----------------------------------------
    # area weighted face sampling
    normalized_area_accumulator = compute_area_weights_for_range3d_faces(range3d)

    def sample_face() -> int:
        """
        Weighted sampling of a face from the area accumulator.
        """
        rand = random.random()
        for ix in range(6):
            if normalized_area_accumulator[ix] > rand:
                return ix
        # raise a real exception instance: raising a tuple is itself a TypeError in Python 3
        raise AssertionError("Should not reach here.")

    # -----------------------------------------

    face_info = get_range3d_sample_planes(range3d)

    # one internal list of each face of the box:
    samples = []
    for _ in range(6):
        samples.append([])

    # sample points for random faces
    for _ in range(num_points):
        face_ix = sample_face()
        f = face_info[face_ix]
        point = f[0] + random.random() * f[1] + random.random() * f[2]
        samples[face_ix].append(point)

    return samples


# End - Range3D surface sampling utils
# =======================================================================


def sample_points_from_sphere(
    center: mn.Vector3,
    radius: float,
    num_points: int = 100,
) -> List[List[mn.Vector3]]:
    """
    Sample num_points from a sphere defined by center and radius.
    Return all points in two identical lists to indicate pairwise raycasting.

    :param center: sphere center position
    :param radius: sphere radius
    :param num_points: number of points to sample
    """
    samples = []

    # sample points
    while len(samples) < num_points:
        # rejection sample unit sphere from volume
        rand_point = np.random.random(3) * 2.0 - np.ones(1)
        vec_len = np.linalg.norm(rand_point)
        # also reject a zero-length sample: it cannot be normalized onto the surface
        if 0.0 < vec_len < 1.0:
            # inside the sphere, so project to the surface
            samples.append(mn.Vector3(rand_point / vec_len))
        # else outside the sphere, so rejected

    # move from unit sphere to input sphere
    samples = [x * radius + center for x in samples]

    # collect into pairwise datastructure
    samples = [samples, samples]

    return samples


def receptacle_density_sample(
    sim: habitat_sim.simulator.Simulator,
    receptacle: hab_receptacle.TriangleMeshReceptacle,
    target_radius: float = 0.04,
    max_points: int = 100,
    min_points: int = 5,
    max_tries: int = 200,
):
    """
    Sample points on a receptacle so that no two accepted points are closer than twice the target radius.

    :param sim: The Simulator instance passed through to the receptacle's sampler.
    :param receptacle: The receptacle to sample. Its `total_area` determines the expected number of points.
    :param target_radius: Radius of the disc each point is expected to cover (area = pi * r^2).
    :param max_points: Upper clamp on the expected number of points.
    :param min_points: Lower clamp on the expected number of points.
    :param max_tries: Sampling stops after this many consecutive rejected samples.
    :return: (sampled_points, target_radius), where target_radius is recomputed if the expected point count had to be clamped.
    """
    target_point_area = math.pi * target_radius**2
    expected_points = receptacle.total_area / target_point_area

    # if necessary, compute new target_radius to best cover the area
    if expected_points > max_points or expected_points < min_points:
        expected_points = max(min_points, min(max_points, expected_points))
        target_radius = math.sqrt(receptacle.total_area / (expected_points * math.pi))

    sampled_points = []
    num_tries = 0
    min_dist = target_radius * 2
    while len(sampled_points) < expected_points and num_tries < max_tries:
        sample_point = receptacle.sample_uniform_global(sim, sample_region_scale=1.0)
        success = True
        for existing_point in sampled_points:
            if (sample_point - existing_point).length() < min_dist:
                num_tries += 1
                success = False
                break
        if success:
            # if no rejection, add the point
            sampled_points.append(sample_point)
            num_tries = 0

    return sampled_points, target_radius


def run_pairwise_raycasts(
    points: List[List[mn.Vector3]],
    sim: habitat_sim.Simulator,
    min_dist: float = 0.05,
    discard_invalid_results: bool = True,
) -> List[habitat_sim.physics.RaycastResults]:
    """
    Raycast between each pair of points from different surfaces.

    Every valid pair produces two results: one ray in each direction.

    :param points: One list of points per surface. Pairs are only formed between different lists.
    :param sim: The Simulator instance used to cast the rays.
    :param min_dist: The minimum ray distance to allow. Cull all candidate pairs closer than this distance.
    :param discard_invalid_results: If true, discard ray hit distances > 1
    :return: The flat list of raycast results.
    """
    ray_max_local_dist = 100.0  # default
    if discard_invalid_results:
        # disallow contacts outside of the bounding volume (shouldn't happen anyway...)
        ray_max_local_dist = 1.0
    all_raycast_results: List[habitat_sim.physics.RaycastResults] = []
    print("Rays detected with invalid hit distance: ")
    for fix0 in range(len(points)):
        for fix1 in range(len(points)):
            if fix0 != fix1:  # no pairs on the same face
                for p0 in points[fix0]:
                    for p1 in points[fix1]:
                        if (p0 - p1).length() > min_dist:
                            # this is a valid pair of points
                            ray = habitat_sim.geo.Ray(p0, p1 - p0)  # origin, direction
                            # raycast
                            all_raycast_results.append(
                                sim.cast_ray(ray=ray, max_distance=ray_max_local_dist)
                            )
                            # reverse direction as separate entry (because exiting a convex does not generate a hit record)
                            ray2 = habitat_sim.geo.Ray(p1, p0 - p1)  # origin, direction
                            # raycast
                            all_raycast_results.append(
                                sim.cast_ray(ray=ray2, max_distance=ray_max_local_dist)
                            )

                            # prints invalid rays if not discarded by discard_invalid_results==True
                            for ix in [-1, -2]:
                                if all_raycast_results[ix].has_hits() and (
                                    all_raycast_results[ix].hits[0].ray_distance > 1
                                    or all_raycast_results[ix].hits[0].ray_distance < 0
                                ):
                                    print(
                                        f" distance={all_raycast_results[ix].hits[0].ray_distance}"
                                    )

    return all_raycast_results


def debug_draw_raycast_results(
    sim, ground_truth_results, proxy_results, subsample_number: int = 100, seed=0
) -> None:
    """
    Render debug lines for a subset of raycast results, randomly subsampled for efficiency.

    :param sim: The Simulator instance providing the debug line renderer.
    :param ground_truth_results: Raycast results against the ground truth shape.
    :param proxy_results: Raycast results against the proxy shape, index-aligned with ground_truth_results.
    :param subsample_number: Number of random draws (with replacement) from the results.
    :param seed: Seed for the subsampling, so the same rays are drawn on every call.
    """
    if len(ground_truth_results) == 0:
        # nothing to draw (and randint(0, -1) would raise)
        return

    # use a local generator so the process-wide `random` state is not reset on every call
    rng = random.Random(seed)
    dlr = sim.get_debug_line_render()
    red = mn.Color4.red()
    yellow = mn.Color4.yellow()
    blue = mn.Color4.blue()
    grey = mn.Color4(mn.Vector3(0.6), 1.0)
    for _ in range(subsample_number):
        result_ix = rng.randint(0, len(ground_truth_results) - 1)
        ray = ground_truth_results[result_ix].ray
        gt_results = ground_truth_results[result_ix]
        pr_results = proxy_results[result_ix]

        if gt_results.has_hits() or pr_results.has_hits():
            # some logic for line colors
            first_hit_dist = 0
            # pairs of distances for overshooting the ground truth and undershooting the ground truth
            overshoot_dists = []
            undershoot_dists = []

            # draw first hits for gt and proxy
            if gt_results.has_hits():
                dlr.draw_circle(
                    translation=ray.origin
                    + ray.direction * gt_results.hits[0].ray_distance,
                    radius=0.005,
                    color=blue,
                    normal=gt_results.hits[0].normal,
                )
            if pr_results.has_hits():
                dlr.draw_circle(
                    translation=ray.origin
                    + ray.direction * pr_results.hits[0].ray_distance,
                    radius=0.005,
                    color=yellow,
                    normal=pr_results.hits[0].normal,
                )

            if not gt_results.has_hits():
                first_hit_dist = pr_results.hits[0].ray_distance
                overshoot_dists.append((first_hit_dist, 1.0))
            elif not pr_results.has_hits():
                first_hit_dist = gt_results.hits[0].ray_distance
                undershoot_dists.append((first_hit_dist, 1.0))
            else:
                # both have hits
                first_hit_dist = min(
                    gt_results.hits[0].ray_distance, pr_results.hits[0].ray_distance
                )

                # compute overshoots and undershoots for first hit:
                if gt_results.hits[0].ray_distance < pr_results.hits[0].ray_distance:
                    # undershoot
                    undershoot_dists.append(
                        (
                            gt_results.hits[0].ray_distance,
                            pr_results.hits[0].ray_distance,
                        )
                    )
                else:
                    # overshoot
                    overshoot_dists.append(
                        (
                            gt_results.hits[0].ray_distance,
                            pr_results.hits[0].ray_distance,
                        )
                    )

            # draw blue lines for overlapping distances
            dlr.draw_transformed_line(
                ray.origin, ray.origin + ray.direction * first_hit_dist, blue
            )

            # draw red lines for overshoots (proxy is outside the ground truth)
            for d0, d1 in overshoot_dists:
                dlr.draw_transformed_line(
                    ray.origin + ray.direction * d0,
                    ray.origin + ray.direction * d1,
                    red,
                )

            # draw yellow lines for undershoots (proxy is inside the ground truth)
            for d0, d1 in undershoot_dists:
                dlr.draw_transformed_line(
                    ray.origin + ray.direction * d0,
                    ray.origin + ray.direction * d1,
                    yellow,
                )

        else:
            # no hits, grey line
            dlr.draw_transformed_line(ray.origin, ray.origin + ray.direction, grey)


def get_raycast_results_cumulative_error_metric(
    ground_truth_results, proxy_results
) -> float:
    """
    Generates a scalar error metric from raycast results normalized to [0,1].

    absolute_error = sum(abs(gt_1st_hit_dist-pr_1st_hit_dist))

    To normalize error:
        0 corresponds to gt distances (absolute_error == 0)
        1 corresponds to max error. For each ray, max error is max(gt_1st_hit_dist, ray_length-gt_1st_hit_dist).
        max_error = sum(max(gt_1st_hit_dist, ray_length-gt_1st_hit_dist))
        normalized_error = error/max_error

    A ray without a hit is treated as hitting at the full ray length.
    Returns 0.0 when there is nothing to compare (no rays, or only zero-length rays).
    """
    assert len(ground_truth_results) == len(
        proxy_results
    ), "raycast results must be equivalent."

    max_error = 0
    absolute_error = 0
    for r_ix in range(len(ground_truth_results)):
        ray = ground_truth_results[r_ix].ray
        ray_len = ray.direction.length()
        local_max_error = ray_len
        gt_dist = ray_len
        if ground_truth_results[r_ix].has_hits():
            # hit distances are in units of ray length, so convert to absolute distance
            gt_dist = ground_truth_results[r_ix].hits[0].ray_distance * ray_len
            local_max_error = max(gt_dist, ray_len - gt_dist)
        max_error += local_max_error
        local_proxy_dist = ray_len
        if proxy_results[r_ix].has_hits():
            local_proxy_dist = proxy_results[r_ix].hits[0].ray_distance * ray_len
        local_absolute_error = abs(local_proxy_dist - gt_dist)
        absolute_error += local_absolute_error

    if max_error == 0:
        # no rays (or only degenerate rays): no measurable error and nothing to normalize by
        return 0.0

    normalized_error = absolute_error / max_error
    return normalized_error


# ===================================================================
# CollisionProxyOptimizer class provides a stateful API for
# configurable evaluation and optimization of collision proxy shapes.
# ===================================================================


class CollisionProxyOptimizer:
    """
    Stateful control flow for using Habitat-sim to evaluate and optimize collision proxy shapes.
    """

    def __init__(
        self,
        sim_settings: Dict[str, Any],
        output_directory: Optional[str] = None,
        mm: Optional[habitat_sim.metadata.MetadataMediator] = None,
    ) -> None:
        """
        :param sim_settings: Simulator settings. Must contain "scene_dataset_config_file". A shallow copy is stored.
        :param output_directory: Optional directory for images/csv output. Created if it does not exist.
        :param mm: Optional MetadataMediator to share. A new one is created when omitted.
        """
        # load the dataset into a persistent, shared MetadataMediator instance.
        self.mm = mm if mm is not None else habitat_sim.metadata.MetadataMediator()
        self.mm.active_dataset = sim_settings["scene_dataset_config_file"]
        self.sim_settings = sim_settings.copy()

        # path to the desired output directory for images/csv
        self.output_directory = output_directory
        if output_directory is not None:
            os.makedirs(self.output_directory, exist_ok=True)

        # if true, render and save debug images in self.output_directory
        self.generate_debug_images = False

        # option to use Receptacle annotations to compute an additional accuracy metric
        self.compute_receptacle_useability_metrics = True
        # add a vertical epsilon offset to the receptacle points for analysis. This is added directly to the sampled points.
        self.rec_point_vertical_offset = 0.041

        self.init_caches()

    def init_caches(self):
        """
        Re-initialize all internal data caches to prepare for re-use.
        """
        # cache of test points, rays, distances, etc... for use by active processes
        # NOTE: entries created by `setup_obj_gt` and cleaned by `clean_obj_gt` for memory efficiency.
        self.gt_data: Dict[str, Dict[str, Any]] = {}

        # cache global results to be written to csv.
        self.results: Dict[str, Dict[str, Any]] = {}

    def get_proxy_index(self, obj_handle: str) -> int:
        """
        Get the current proxy index for an object.
        """
        return self.gt_data[obj_handle]["proxy_index"]

    def increment_proxy_index(self, obj_handle: str) -> None:
        """
        Increment the current proxy index.
        Only do this after all processing for the current proxy is complete.
        """
        self.gt_data[obj_handle]["proxy_index"] += 1

    def get_proxy_shape_id(self, obj_handle: str) -> str:
        """
        Get a string representation of the current proxy shape.
        """
        return f"pr{self.get_proxy_index(obj_handle)}"

    def get_cfg_with_mm(
        self, scene: str = "NONE"
    ) -> habitat_sim.simulator.Configuration:
        """
        Get a Configuration object for initializing habitat_sim Simulator object with the correct dataset and MetadataMediator passed along.

        :param scene: The desired scene entry, defaulting to the empty NONE scene.
        """
        sim_settings = self.sim_settings.copy()
        sim_settings["scene_dataset_config_file"] = self.mm.active_dataset
        sim_settings["scene"] = scene
        cfg = make_cfg(sim_settings)
        cfg.metadata_mediator = self.mm
        return cfg

    def setup_obj_gt(
        self,
        obj_handle: str,
        sample_shape: str = "jittered_aabb",
        num_point_samples=100,
    ) -> None:
        """
        Prepare the ground truth and sample point sets for an object.

        :param obj_handle: Handle of the object template. Must not already have an entry in `gt_data`.
        :param sample_shape: How raycast test points are sampled: "aabb", "jittered_aabb" or "sphere".
        :param num_point_samples: Number of raycast test points requested from the sampler.
        """
        assert (
            obj_handle not in self.gt_data
        ), f"`{obj_handle}` already setup in gt_data: {self.gt_data.keys()}"

        # find object
        otm = self.mm.object_template_manager
        obj_template = otm.get_template_by_handle(obj_handle)
        assert obj_template is not None, f"Could not find object handle `{obj_handle}`"

        # create a stage template with the object's render mesh as a "ground truth" for metrics
        stm = self.mm.stage_template_manager
        stage_template_name = obj_handle + "_as_stage"
        new_stage_template = stm.create_new_template(handle=stage_template_name)
        new_stage_template.render_asset_handle = obj_template.render_asset_handle
        new_stage_template.orient_up = obj_template.orient_up
        new_stage_template.orient_front = obj_template.orient_front
        stm.register_template(
            template=new_stage_template, specified_handle=stage_template_name
        )

        # initialize the object's runtime data cache
        self.gt_data[obj_handle] = {
            "proxy_index": 0,  # used to recover and increment `shape_id` during optimization and evaluation
            "stage_template_name": stage_template_name,
            "receptacles": {},  # sub-cache for receptacle metric data and results
            "raycasts": {},  # subcache for shape raycasting metric data
            "shape_test_results": {
                "gt": {}
            },  # subcache for shape and physics metric results
        }

        # correct now for any COM automation
        obj_template.compute_COM_from_shape = False
        obj_template.com = mn.Vector3(0)
        otm.register_template(obj_template)

        if self.compute_receptacle_useability_metrics or self.generate_debug_images:
            # pre-process the ground truth object and receptacles
            rec_vertical_offset = mn.Vector3(0, self.rec_point_vertical_offset, 0)
            cfg = self.get_cfg_with_mm()
            with habitat_sim.Simulator(cfg) as sim:
                # load the gt object
                rom = sim.get_rigid_object_manager()
                obj = rom.add_object_by_template_handle(obj_handle)
                assert obj.is_alive, "Object was not added correctly."

                if self.compute_receptacle_useability_metrics:
                    # get receptacles defined for the object:
                    source_template_file = obj.creation_attributes.file_directory
                    user_attr = obj.user_attributes
                    obj_receptacles = hab_receptacle.parse_receptacles_from_user_config(
                        user_attr,
                        parent_object_handle=obj.handle,
                        parent_template_directory=source_template_file,
                    )

                    # sample test points on the receptacles
                    for receptacle in obj_receptacles:
                        if isinstance(
                            receptacle, hab_receptacle.TriangleMeshReceptacle
                        ):
                            # adaptive density sample:
                            rec_test_points, t_radius = receptacle_density_sample(
                                sim, receptacle
                            )
                            # add the vertical offset
                            rec_test_points = [
                                p + rec_vertical_offset for p in rec_test_points
                            ]

                            self.gt_data[obj_handle]["receptacles"][receptacle.name] = {
                                "sample_points": rec_test_points,
                                "shape_id_results": {},
                            }
                            if self.generate_debug_images:
                                # outline every triangle of the receptacle mesh
                                debug_lines = []
                                for face in range(
                                    int(len(receptacle.mesh_data.indices) / 3)
                                ):
                                    verts = receptacle.get_face_verts(f_ix=face)
                                    for edge in range(3):
                                        debug_lines.append(
                                            (
                                                [verts[edge], verts[(edge + 1) % 3]],
                                                mn.Color4.green(),
                                            )
                                        )
                                # one circle per sample point
                                debug_circles = []
                                for p in rec_test_points:
                                    debug_circles.append(
                                        (
                                            (
                                                p,  # center
                                                t_radius,  # radius
                                                mn.Vector3(0, 1, 0),  # normal
                                                mn.Color4.red(),  # color
                                            )
                                        )
                                    )
                                if self.output_directory is not None:
                                    # use DebugVisualizer to get 6-axis view of the object
                                    dvb = hab_debug_vis.DebugVisualizer(
                                        sim=sim,
                                        output_path=self.output_directory,
                                        default_sensor_uuid="color_sensor",
                                    )
                                    dvb.peek_rigid_object(
                                        obj,
                                        peek_all_axis=True,
                                        additional_savefile_prefix=f"{receptacle.name}_",
                                        debug_lines=debug_lines,
                                        debug_circles=debug_circles,
                                    )

                if self.generate_debug_images and self.output_directory is not None:
                    # use DebugVisualizer to get 6-axis view of the object
                    dvb = hab_debug_vis.DebugVisualizer(
                        sim=sim,
                        output_path=self.output_directory,
                        default_sensor_uuid="color_sensor",
                    )
                    dvb.peek_rigid_object(
                        obj, peek_all_axis=True, additional_savefile_prefix="gt_"
                    )

        # load a simulator instance with this object as the stage
        cfg = self.get_cfg_with_mm(scene=stage_template_name)
        with habitat_sim.Simulator(cfg) as sim:
            # get test points from bounding box info:
            scene_bb = sim.get_active_scene_graph().get_root_node().cumulative_bb
            inflated_scene_bb = scene_bb.scaled(mn.Vector3(1.25))
            inflated_scene_bb = mn.Range3D.from_center(
                scene_bb.center(), inflated_scene_bb.size() / 2.0
            )
            # NOTE: to save the referenced Range3D object, we need to deep copy or Magnum will destroy the underlying C++ objects.
            self.gt_data[obj_handle]["scene_bb"] = mn.Range3D(
                scene_bb.min, scene_bb.max
            )
            self.gt_data[obj_handle]["inflated_scene_bb"] = inflated_scene_bb
            test_points = None
            if sample_shape == "aabb":
                # bounding box sample
                test_points = sample_points_from_range3d(
                    range3d=inflated_scene_bb, num_points=num_point_samples
                )
            elif sample_shape == "jittered_aabb":
                # bounding box sample
                test_points = sample_jittered_points_from_range3d(
                    range3d=inflated_scene_bb, num_points=num_point_samples
                )
            elif sample_shape == "sphere":
                # bounding sphere sample
                half_diagonal = (scene_bb.max - scene_bb.min).length() / 2.0
                test_points = sample_points_from_sphere(
                    center=inflated_scene_bb.center(),
                    radius=half_diagonal,
                    num_points=num_point_samples,
                )
            else:
                raise NotImplementedError(
                    f"sample_shape == `{sample_shape}` is not implemented. Use `sphere`, `aabb`, or `jittered_aabb`."
                )
            self.gt_data[obj_handle]["test_points"] = test_points

            # compute and cache "ground truth" raycast on object as stage
            gt_raycast_results = run_pairwise_raycasts(test_points, sim)
            self.gt_data[obj_handle]["raycasts"]["gt"] = gt_raycast_results

    def clean_obj_gt(self, obj_handle: str) -> None:
        """
        Cleans the global object cache to better manage process memory.
        Call this to clean-up after global data are written and detailed sample data are no longer necessary.
        """
        assert (
            obj_handle in self.gt_data
        ), f"`{obj_handle}` does not have any entry in gt_data: {self.gt_data.keys()}. Call to `setup_obj_gt(obj_handle)` required."
        self.gt_data.pop(obj_handle)

    def compute_baseline_metrics(self, obj_handle: str) -> None:
        """
        Computes 2 baselines for the evaluation metric and caches the results:
        1. No collision object
        2. AABB collision object

        Results are cached in `gt_data[obj_handle]["raycasts"]` under the keys "empty" and "bb".
        """
        assert (
            obj_handle in self.gt_data
        ), f"`{obj_handle}` does not have any entry in gt_data: {self.gt_data.keys()}. Call to `setup_obj_gt(obj_handle)` required."

        # start with empty scene
        cfg = self.get_cfg_with_mm()
        with habitat_sim.Simulator(cfg) as sim:
            empty_raycast_results = run_pairwise_raycasts(
                self.gt_data[obj_handle]["test_points"], sim
            )
            self.gt_data[obj_handle]["raycasts"]["empty"] = empty_raycast_results

        cfg = self.get_cfg_with_mm()
        with habitat_sim.Simulator(cfg) as sim:
            # modify the template
            obj_template = sim.get_object_template_manager().get_template_by_handle(
                obj_handle
            )
            assert (
                obj_template is not None
            ), f"Could not find object handle `{obj_handle}`"
            # bounding box as collision object
            obj_template.bounding_box_collisions = True
            sim.get_object_template_manager().register_template(obj_template)

            try:
                # load the object
                sim.get_rigid_object_manager().add_object_by_template_handle(obj_handle)

                # run evaluation
                bb_raycast_results = run_pairwise_raycasts(
                    self.gt_data[obj_handle]["test_points"], sim
                )
                self.gt_data[obj_handle]["raycasts"]["bb"] = bb_raycast_results
            finally:
                # un-modify the template, even if the evaluation failed
                obj_template.bounding_box_collisions = False
                sim.get_object_template_manager().register_template(obj_template)

    def compute_proxy_metrics(self, obj_handle: str) -> None:
        """
        Computes the evaluation metric on the currently configured proxy shape and caches the results.

        Results are cached in `gt_data[obj_handle]["raycasts"]` under the current proxy shape id.
        """
        assert (
            obj_handle in self.gt_data
        ), f"`{obj_handle}` does not have any entry in gt_data: {self.gt_data.keys()}. Call to `setup_obj_gt(obj_handle)` required."

        # when evaluating multiple proxy shapes, need unique ids:
        pr_id = self.get_proxy_shape_id(obj_handle)

        # start with empty scene
        cfg = self.get_cfg_with_mm()
        with habitat_sim.Simulator(cfg) as sim:
            # modify the template to render collision object
            otm = self.mm.object_template_manager
            obj_template = otm.get_template_by_handle(obj_handle)
            render_asset = obj_template.render_asset_handle
            obj_template.render_asset_handle = obj_template.collision_asset_handle
            otm.register_template(obj_template)

            try:
                # load the object
                obj = sim.get_rigid_object_manager().add_object_by_template_handle(
                    obj_handle
                )
                assert obj.is_alive, "Object was not added correctly."

                # check that collision shape bounding box is similar
                col_bb = obj.root_scene_node.cumulative_bb
                assert self.gt_data[obj_handle]["inflated_scene_bb"].contains(
                    col_bb.min
                ) and self.gt_data[obj_handle]["inflated_scene_bb"].contains(
                    col_bb.max
                ), f"Inflated bounding box does not contain the collision shape. (Object `{obj_handle}`)"

                if self.generate_debug_images and self.output_directory is not None:
                    # use DebugVisualizer to get 6-axis view of the object
                    dvb = hab_debug_vis.DebugVisualizer(
                        sim=sim,
                        output_path=self.output_directory,
                        default_sensor_uuid="color_sensor",
                    )
                    dvb.peek_rigid_object(
                        obj, peek_all_axis=True, additional_savefile_prefix=pr_id + "_"
                    )

                # run evaluation
                pr_raycast_results = run_pairwise_raycasts(
                    self.gt_data[obj_handle]["test_points"], sim
                )
                self.gt_data[obj_handle]["raycasts"][pr_id] = pr_raycast_results
            finally:
                # undo template modification, even if an assertion or the evaluation failed,
                # so the shared MetadataMediator is not left rendering the collision asset
                obj_template.render_asset_handle = render_asset
                otm.register_template(obj_template)

    def compute_receptacle_access_metrics(
        self, obj_handle: str, use_gt=False, acces_ratio_threshold: float = 0.1
    ):
        """
        Compute a heuristic for the accessibility of all Receptacles for an object.
        Uses raycasting from previously sampled receptacle locations to approximate how open a particular receptacle is.

        :param use_gt: Compute the metric for the ground truth shape instead of the currently active collision proxy (default)
        :param acces_ratio_threshold: The ratio of accessible:blocked rays necessary for a Receptacle point to be considered accessible
        """
        # algorithm:
        # For each receptacle, r:
        #   For each sample point, s:
        #     Generate `num_point_rays` directions, d (length bb diagonal) and Ray(origin=s+d, direction=d)
        #     For each ray:
        #       If dist > 1, success, otherwise failure
        #
        # metrics:
        # - %rays
        # - %points w/ success% > eps(10%) #call these successful/accessible
        # - average % for points
        # ? how to get regions?
        # ? debug draw this metric?
        # ? how to diff b/t gt and pr?

        print(f"compute_receptacle_access_metrics - obj_handle = {obj_handle}")

        # start with empty scene or stage as scene:
        scene_name = "NONE"
        if use_gt:
            scene_name = self.gt_data[obj_handle]["stage_template_name"]
        cfg = self.get_cfg_with_mm(scene=scene_name)
        with habitat_sim.Simulator(cfg) as sim:
            obj_rec_data = self.gt_data[obj_handle]["receptacles"]
            shape_id = "gt"
            obj = None
            if not use_gt:
                # load the object
                obj = sim.get_rigid_object_manager().add_object_by_template_handle(
                    obj_handle
                )
                assert obj.is_alive, "Object was not added correctly."

                # when evaluating multiple proxy shapes, need unique ids:
                shape_id = self.get_proxy_shape_id(obj_handle)

            # gather hemisphere rays scaled to object's size
            # NOTE: because the receptacle points can be located anywhere in the bounding box, raycast radius must be bb diagonal length
            ray_sphere_radius = self.gt_data[obj_handle]["scene_bb"].size().length()
            assert ray_sphere_radius > 0, "otherwise we have an error"
            ray_sphere_points = get_scaled_hemisphere_vectors(ray_sphere_radius)

            # save a list of point accessibility scores for debugging and visualization
            receptacle_point_access_scores = {}
            dvb: Optional[hab_debug_vis.DebugVisualizer] = None
            if self.output_directory is not None:
                dvb = hab_debug_vis.DebugVisualizer(
                    sim=sim,
                    output_path=self.output_directory,
                    default_sensor_uuid="color_sensor",
                )

            # collect hemisphere raycast samples for all receptacle sample points
            for receptacle_name in obj_rec_data:
                sample_point_ray_results: List[
                    List[habitat_sim.physics.RaycastResults]
                ] = []
                sample_point_access_ratios: List[float] = []
                # access rate is percent of "accessible" points passing the threshold
                receptacle_access_rate = 0
                # access score is average accessibility of points
                receptacle_access_score = 0
                sample_points = obj_rec_data[receptacle_name]["sample_points"]
                for sample_point in sample_points:
                    # NOTE: rays must originate outside the shape because origins inside a convex will not collide.
                    # move ray origins to new point location
                    hemi_rays = [
                        habitat_sim.geo.Ray(v + sample_point, -v)
                        for v in ray_sphere_points
                    ]
                    # rays are not unit length, so use local max_distance==1 ray length
                    ray_results = [
                        sim.cast_ray(ray=ray, max_distance=1.0) for ray in hemi_rays
                    ]
                    sample_point_ray_results.append(ray_results)

                    # compute per-point access metrics
                    blocked_rays = len([rr for rr in ray_results if rr.has_hits()])
                    sample_point_access_ratios.append(
                        (len(ray_results) - blocked_rays) / len(ray_results)
                    )
                    receptacle_access_score += sample_point_access_ratios[-1]
                    if sample_point_access_ratios[-1] > acces_ratio_threshold:
                        receptacle_access_rate += 1
                    receptacle_point_access_scores[
                        receptacle_name
                    ] = sample_point_access_ratios

                receptacle_access_score /= len(sample_points)
                receptacle_access_rate /= len(sample_points)

                if shape_id not in obj_rec_data[receptacle_name]["shape_id_results"]:
                    obj_rec_data[receptacle_name]["shape_id_results"][shape_id] = {}
                assert (
                    "access_results"
                    not in obj_rec_data[receptacle_name]["shape_id_results"][shape_id]
                ), f"Overwriting existing 'access_results' data for '{receptacle_name}'|'{shape_id}'."
+ obj_rec_data[receptacle_name]["shape_id_results"][shape_id][ + "access_results" + ] = { + "receptacle_point_access_scores": receptacle_point_access_scores[ + receptacle_name + ], + "sample_point_ray_results": sample_point_ray_results, + "receptacle_access_score": receptacle_access_score, + "receptacle_access_rate": receptacle_access_rate, + } + access_results = obj_rec_data[receptacle_name]["shape_id_results"][ + shape_id + ]["access_results"] + + print(f" receptacle_name = {receptacle_name}") + print(f" receptacle_access_score = {receptacle_access_score}") + print(f" receptacle_access_rate = {receptacle_access_rate}") + + if self.generate_debug_images and dvb is not None: + # generate receptacle access debug images + # 1a Show missed rays vs 1b hit rays + debug_lines = [] + for ray_results in access_results["sample_point_ray_results"]: + for hit_record in ray_results: + if not hit_record.has_hits(): + debug_lines.append( + ( + [ + hit_record.ray.origin, + hit_record.ray.origin + + hit_record.ray.direction, + ], + mn.Color4.green(), + ) + ) + if use_gt: + dvb.peek_scene( + peek_all_axis=True, + additional_savefile_prefix=f"gt_{receptacle_name}_access_rays_", + debug_lines=debug_lines, + debug_circles=None, + ) + else: + dvb.peek_rigid_object( + obj, + peek_all_axis=True, + additional_savefile_prefix=f"{shape_id}_{receptacle_name}_access_rays_", + debug_lines=debug_lines, + debug_circles=None, + ) + + # 2 Show only rec points colored by "access" metric or percentage + debug_circles = [] + color_r = mn.Color4.red().to_xyz() + color_g = mn.Color4.green().to_xyz() + delta = color_g - color_r + for point_access_ratio, point in zip( + receptacle_point_access_scores[receptacle_name], + obj_rec_data[receptacle_name]["sample_points"], + ): + point_color_xyz = color_r + delta * point_access_ratio + debug_circles.append( + ( + point, + 0.02, + mn.Vector3(0, 1, 0), + mn.Color4.from_xyz(point_color_xyz), + ) + ) + # use DebugVisualizer to get 6-axis view of the object + if 
use_gt: + dvb.peek_scene( + peek_all_axis=True, + additional_savefile_prefix=f"gt_{receptacle_name}_point_ratios_", + debug_lines=None, + debug_circles=debug_circles, + ) + else: + dvb.peek_rigid_object( + obj, + peek_all_axis=True, + additional_savefile_prefix=f"{shape_id}_{receptacle_name}_point_ratios_", + debug_lines=None, + debug_circles=debug_circles, + ) + # obj_rec_data[receptacle_name]["results"][shape_id]["sample_point_ray_results"] + + def construct_cylinder_object( + self, + mm: habitat_sim.metadata.MetadataMediator, + cyl_radius: float = 0.04, + cyl_height: float = 0.15, + ): + constructed_cyl_temp_name = "scaled_cyl_template" + otm = mm.object_template_manager + cyl_temp_handle = otm.get_synth_template_handles("cylinder")[0] + cyl_temp = otm.get_template_by_handle(cyl_temp_handle) + cyl_temp.scale = mn.Vector3(cyl_radius, cyl_height / 2.0, cyl_radius) + otm.register_template(cyl_temp, constructed_cyl_temp_name) + return constructed_cyl_temp_name + + def compute_receptacle_stability( + self, + obj_handle: str, + use_gt: bool = False, + cyl_radius: float = 0.04, + cyl_height: float = 0.15, + accepted_height_error: float = 0.1, + ): + """ + Try to place a dynamic cylinder on the receptacle points. Record snap error and physical stability. + + :param obj_handle: The object to evaluate. + :param use_gt: Compute the metric for the ground truth shape instead of the currently active collision proxy (default) + :param cyl_radius: Radius of the test cylinder object (default similar to food can) + :param cyl_height: Height of the test cylinder object (default similar to food can) + :param accepted_height_error: The acceptacle distance from receptacle to snapped point considered successful (meters) + """ + + constructed_cyl_obj_handle = self.construct_cylinder_object( + self.mm, cyl_radius, cyl_height + ) + + assert ( + len(self.gt_data[obj_handle]["receptacles"].keys()) > 0 + ), "Object must have receptacle sampling metadata defined. 
See `setup_obj_gt`" + + # start with empty scene or stage as scene: + scene_name = "NONE" + if use_gt: + scene_name = self.gt_data[obj_handle]["stage_template_name"] + cfg = self.get_cfg_with_mm(scene=scene_name) + with habitat_sim.Simulator(cfg) as sim: + dvb: Optional[hab_debug_vis.DebugVisualizer] = None + if self.generate_debug_images and self.output_directory is not None: + dvb = hab_debug_vis.DebugVisualizer( + sim=sim, + output_path=self.output_directory, + default_sensor_uuid="color_sensor", + ) + # load the object + rom = sim.get_rigid_object_manager() + obj = None + support_obj_ids = [-1] + shape_id = "gt" + if not use_gt: + obj = rom.add_object_by_template_handle(obj_handle) + support_obj_ids = [obj.object_id] + assert obj.is_alive, "Object was not added correctly." + # need to make the object STATIC so it doesn't move + obj.motion_type = habitat_sim.physics.MotionType.STATIC + # when evaluating multiple proxy shapes, need unique ids: + shape_id = self.get_proxy_shape_id(obj_handle) + + # add the test object + cyl_test_obj = rom.add_object_by_template_handle(constructed_cyl_obj_handle) + cyl_test_obj_com_height = cyl_test_obj.root_scene_node.cumulative_bb.max[1] + assert cyl_test_obj.is_alive, "Test object was not added correctly." 
+ + # we sample above the receptacle to account for margin, but we compare distance to the actual receptacle height + receptacle_sample_height_correction = mn.Vector3( + 0, -self.rec_point_vertical_offset, 0 + ) + + # evaluation the sample points for each receptacle + rec_data = self.gt_data[obj_handle]["receptacles"] + for rec_name in rec_data: + sample_points = rec_data[rec_name]["sample_points"] + + failed_snap = 0 + failed_by_distance = 0 + failed_unstable = 0 + point_stabilities = [] + for sample_point in sample_points: + cyl_test_obj.translation = sample_point + cyl_test_obj.rotation = mn.Quaternion.identity_init() + # snap check + success = snap_down( + sim, cyl_test_obj, support_obj_ids=support_obj_ids, vdb=dvb + ) + if success: + expected_height_error = abs( + ( + cyl_test_obj.translation + - (sample_point + receptacle_sample_height_correction) + ).length() + - cyl_test_obj_com_height + ) + if expected_height_error > accepted_height_error: + failed_by_distance += 1 + point_stabilities.append(False) + continue + + # physical stability analysis + snap_position = cyl_test_obj.translation + identity_q = mn.Quaternion.identity_init() + displacement_limit = 0.04 # meters + rotation_limit = mn.Rad(0.1) # radians + max_sim_time = 3.0 + dt = 0.5 + start_time = sim.get_world_time() + object_is_stable = True + while sim.get_world_time() - start_time < max_sim_time: + sim.step_world(dt) + linear_displacement = ( + cyl_test_obj.translation - snap_position + ).length() + # NOTE: negative quaternion represents the same rotation, but gets a different angle error so check both + angular_displacement = min( + mn.math.half_angle(cyl_test_obj.rotation, identity_q), + mn.math.half_angle( + -1 * cyl_test_obj.rotation, identity_q + ), + ) + if ( + angular_displacement > rotation_limit + or linear_displacement > displacement_limit + ): + object_is_stable = False + break + if not cyl_test_obj.awake: + # the object has settled, no need to continue simulating + break + # NOTE: we 
def setup_shape_test_results_cache(self, obj_handle: str, shape_id: str) -> None:
    """
    Make sure `self.gt_data[obj_handle]["shape_test_results"]` has an entry for 'shape_id'.

    A missing entry is created with one empty dict per report type. An existing entry is left untouched.
    """
    results_by_shape = self.gt_data[obj_handle]["shape_test_results"]
    if shape_id in results_by_shape:
        # already initialized: keep whatever has been recorded so far
        return
    report_names = ("settle_report", "sphere_shake_report", "collision_grid_report")
    results_by_shape[shape_id] = {report_name: {} for report_name in report_names}
+ """ + + cfg = self.get_cfg_with_mm() + with habitat_sim.Simulator(cfg) as sim: + rom = sim.get_rigid_object_manager() + obj = rom.add_object_by_template_handle(obj_handle) + assert obj.is_alive, "Object was not added correctly." + + # when evaluating multiple proxy shapes, need unique ids: + shape_id = self.get_proxy_shape_id(obj_handle) + self.setup_shape_test_results_cache(obj_handle, shape_id) + + # add a plane + otm = sim.get_object_template_manager() + cube_plane_handle = "cubePlaneSolid" + if not otm.get_library_has_handle(cube_plane_handle): + cube_prim_handle = otm.get_template_handles("cubeSolid")[0] + cube_template = otm.get_template_by_handle(cube_prim_handle) + cube_template.scale = mn.Vector3(20, 0.05, 20) + otm.register_template(cube_template, cube_plane_handle) + assert otm.get_library_has_handle(cube_plane_handle) + plane_obj = rom.add_object_by_template_handle(cube_plane_handle) + assert plane_obj.is_alive, "Plane object was not added correctly." + plane_obj.motion_type = habitat_sim.physics.MotionType.STATIC + + # use DebugVisualizer to get 6-axis view of the object + dvb: Optional[hab_debug_vis.DebugVisualizer] = None + if self.generate_debug_images and self.output_directory is not None: + dvb = hab_debug_vis.DebugVisualizer( + sim=sim, + output_path=self.output_directory, + default_sensor_uuid="color_sensor", + ) + dvb.peek_rigid_object( + obj, + peek_all_axis=True, + additional_savefile_prefix=f"plane_snap_{shape_id}_", + ) + + # snap the object to the plane + obj_col_bb = obj.collision_shape_aabb + obj.translation = mn.Vector3(0, obj_col_bb.max[1] - obj_col_bb.min[1], 0) + success = snap_down(sim, obj, support_obj_ids=[plane_obj.object_id]) + + if not success: + print("Failed to snap object to plane...") + self.gt_data[obj_handle]["shape_test_results"][shape_id][ + "settle_report" + ] = { + "success": False, + "realtime": "NA", + "max_time": "NA", + "settle_time": "NA", + } + return + + # simulate for settling + max_sim_time = 5.0 + dt = 
0.25 + real_start_time = time.time() + object_is_stable = False + start_time = sim.get_world_time() + while sim.get_world_time() - start_time < max_sim_time: + sim.step_world(dt) + # dvb.peek_rigid_object( + # obj, + # peek_all_axis=True, + # additional_savefile_prefix=f"plane_snap_{sim.get_world_time() - start_time}_", + # ) + + if not obj.awake: + object_is_stable = True + # the object has settled, no need to continue simulating + break + real_test_time = time.time() - real_start_time + sim_settle_time = sim.get_world_time() - start_time + print(f"Physics Settle Time Report: '{obj_handle}'") + if object_is_stable: + print(f" Settled in {sim_settle_time} sim seconds.") + else: + print(f" Failed to settle in {max_sim_time} sim seconds.") + print(f" Test completed in {real_test_time} seconds.") + + self.gt_data[obj_handle]["shape_test_results"][shape_id][ + "settle_report" + ] = { + "success": object_is_stable, + "realtime": real_test_time, + "max_time": max_sim_time, + "settle_time": sim_settle_time, + } + + def compute_grid_collision_times(self, obj_handle, subdivisions=0, use_gt=False): + """ + Runs a collision test over a subdivided grid of box shapes within the object's AABB. + Measures discrete collision check efficiency. + + "param subdivisions": number of recursive subdivisions to create the grid. E.g. 0 is the bb, 1 is 8 box of 1/2 bb size, etc... + """ + + scene_name = "NONE" + if use_gt: + scene_name = self.gt_data[obj_handle]["stage_template_name"] + cfg = self.get_cfg_with_mm(scene=scene_name) + with habitat_sim.Simulator(cfg) as sim: + rom = sim.get_rigid_object_manager() + shape_id = "gt" + shape_bb = None + if not use_gt: + obj = rom.add_object_by_template_handle(obj_handle) + assert obj.is_alive, "Object was not added correctly." 
+ # need to make the object STATIC so it doesn't move + obj.motion_type = habitat_sim.physics.MotionType.STATIC + # when evaluating multiple proxy shapes, need unique ids: + shape_id = self.get_proxy_shape_id(obj_handle) + shape_bb = obj.root_scene_node.cumulative_bb + else: + shape_bb = sim.get_active_scene_graph().get_root_node().cumulative_bb + + self.setup_shape_test_results_cache(obj_handle, shape_id) + + # add the collision box + otm = sim.get_object_template_manager() + cube_prim_handle = otm.get_template_handles("cubeSolid")[0] + cube_template = otm.get_template_by_handle(cube_prim_handle) + num_segments = 2**subdivisions + subdivision_scale = 1.0 / (num_segments) + cube_template.scale = shape_bb.size() * subdivision_scale + # TODO: test this scale + otm.register_template(cube_template, "cubeTestSolid") + + test_obj = rom.add_object_by_template_handle("cubeTestSolid") + assert test_obj.is_alive, "Test box object was not added correctly." + + cell_scale = cube_template.scale + # run the grid test + test_start_time = time.time() + max_col_time = 0 + for x in range(num_segments): + for y in range(num_segments): + for z in range(num_segments): + box_center = ( + shape_bb.min + + mn.Vector3.x_axis(cell_scale[0]) * x + + mn.Vector3.y_axis(cell_scale[1]) * y + + mn.Vector3.z_axis(cell_scale[2]) * z + + cell_scale / 2.0 + ) + test_obj.translation = box_center + col_start = time.time() + test_obj.contact_test() + col_time = time.time() - col_start + max_col_time = max(max_col_time, col_time) + total_test_time = time.time() - test_start_time + avg_test_time = total_test_time / (num_segments**3) + + print( + f"Physics grid collision test report: {obj_handle}. {subdivisions} subdivisions." + ) + print( + f" Test took {total_test_time} seconds for {num_segments**3} collision tests." 
def run_physics_sphere_shake_test(self, obj_handle):
    """
    Places the object in a sphere stage with small sphere primitives and rotates gravity to mix the objects.
    The wall-clock time needed to simulate a fixed amount of sim time serves as a metric for dynamic simulation efficiency.

    The result is written to `self.gt_data[obj_handle]["shape_test_results"][<shape_id>]["sphere_shake_report"]`
    with keys "realtime", "sim_time" and "num_spheres".

    :param obj_handle: The object to evaluate. Must already have an entry in `self.gt_data`.
    """

    # prepare a sphere stage
    sphere_radius = self.gt_data[obj_handle]["scene_bb"].size().length() * 1.5
    sphere_stage_handle = "sphereTestStage"
    stm = self.mm.stage_template_manager
    sphere_template = stm.create_new_template(sphere_stage_handle)
    sphere_template.render_asset_handle = "data/test_assets/objects/sphere.glb"
    sphere_template.scale = mn.Vector3(sphere_radius * 2.0)  # glb is radius 0.5
    stm.register_template(sphere_template, sphere_stage_handle)

    # prepare the test sphere object
    otm = self.mm.object_template_manager
    sphere_test_handle = "sphereTestCollisionObject"
    sphere_prim_handle = otm.get_template_handles("sphereSolid")[0]
    sphere_template = otm.get_template_by_handle(sphere_prim_handle)
    test_sphere_radius = sphere_radius / 100.0
    sphere_template.scale = mn.Vector3(test_sphere_radius)
    otm.register_template(sphere_template, sphere_test_handle)
    assert otm.get_library_has_handle(sphere_test_handle)

    shape_id = self.get_proxy_shape_id(obj_handle)
    self.setup_shape_test_results_cache(obj_handle, shape_id)

    cfg = self.get_cfg_with_mm(scene=sphere_stage_handle)
    with habitat_sim.Simulator(cfg) as sim:
        rom = sim.get_rigid_object_manager()
        obj = rom.add_object_by_template_handle(obj_handle)
        assert obj.is_alive, "Object was not added correctly."

        # fill the remaining space with small spheres
        max_spheres = 100
        max_tries_per_sphere = 50
        num_spheres = 0
        while num_spheres < max_spheres:
            sphere_obj = rom.add_object_by_template_handle(sphere_test_handle)
            assert sphere_obj.is_alive, "Object was not added correctly."
            for _ in range(max_tries_per_sphere):
                # rejection-sample a point from the [-1, 1) cube until it lies within radius 0.99
                # NOTE(review): the stage is scaled from `sphere_radius` but points are sampled
                # within a fixed radius of 0.99 - confirm this is intended.
                new_point = mn.Vector3(np.random.random(3) * 2.0 - np.ones(1))
                while new_point.length() >= 0.99:
                    new_point = mn.Vector3(np.random.random(3) * 2.0 - np.ones(1))
                sphere_obj.translation = new_point
                if not sphere_obj.contact_test():
                    # collision free placement found, keep this sphere
                    num_spheres += 1
                    break
            else:
                # every attempt collided, so discard the sphere and end the search.
                # for/else keeps a sphere which was placed on its final attempt.
                rom.remove_object_by_handle(sphere_obj.handle)
                break

        # run the simulation for timing
        gravity = sim.get_gravity()
        grav_rotation_rate = 0.5  # revolutions per second
        max_sim_time = 10.0
        dt = 0.25
        real_start_time = time.time()
        start_time = sim.get_world_time()
        while sim.get_world_time() - start_time < max_sim_time:
            sim.step_world(dt)
            # change gravity
            cur_time = sim.get_world_time() - start_time
            grav_revolutions = grav_rotation_rate * cur_time
            # rotate the gravity vector around the Z axis
            g_quat = mn.Quaternion.rotation(
                mn.Rad(grav_revolutions * mn.math.pi * 2), mn.Vector3(0, 0, 1)
            )
            sim.set_gravity(g_quat.transform_vector(gravity))

        real_test_time = time.time() - real_start_time

        print(f"Physics 'sphere shake' report: {obj_handle}")
        print(
            f" {num_spheres} spheres took {real_test_time} seconds for {max_sim_time} sim seconds."
        )

        self.gt_data[obj_handle]["shape_test_results"][shape_id][
            "sphere_shake_report"
        ] = {
            "realtime": real_test_time,
            "sim_time": max_sim_time,
            "num_spheres": num_spheres,
        }
def permute_param_variations(
    self, param_ranges: Dict[str, List[Any]]
) -> List[List[Any]]:
    """
    Generate a list of all permutations of the provided parameter ranges defined in a Dict.

    :param param_ranges: Maps each parameter name to the list of values to try for it.
    :return: One list per permutation, each holding an `(attr, value)` tuple per parameter. Within a permutation the parameters appear in reverse dict order, and the last parameter in the dict varies slowest. An empty dict yields `[[]]`; any empty value list yields `[]`.
    """
    import itertools

    # reversed so the product reproduces the ordering described above
    attrs = list(param_ranges)[::-1]
    value_ranges = [param_ranges[attr] for attr in attrs]
    permutations = [
        list(zip(attrs, values)) for values in itertools.product(*value_ranges)
    ]

    print(f"Parameter permutations = {len(permutations)}")
    for setting in permutations:
        print(f" {setting}")

    return permutations
def run_coacd(
    self,
    obj_template_handle: str,
    params: COACDParams,
    output_file: Optional[str] = None,
) -> "Tuple[str, int]":
    """
    Run COACD on an object given a set of parameters producing a file.
    If output_file is not provided, defaults to "COACD_output/<obj_name>_<shape_id>.glb" where obj_name is the truncated handle (filename, no path or file ending) and shape_id is the object's current proxy shape id.

    :param obj_template_handle: Handle of the object template whose render mesh is decomposed.
    :param params: The COACD parameters forwarded to `coacd.run_coacd`.
    :param output_file: Optional path of the exported file.
    :return: Tuple(output_file, number of convex parts produced).
    """
    assert (
        coacd_imported
    ), "coacd is not installed. Linux only: 'pip install coacd'."
    if output_file is None:
        obj_name = obj_template_handle.split(".object_config.json")[0].split("/")[
            -1
        ]
        output_file = (
            "COACD_output/"
            + obj_name
            + "_"
            + self.get_proxy_shape_id(obj_template_handle)
            + ".glb"
        )
    # a bare filename has no directory part and os.makedirs("") raises
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    input_filepath = self.get_obj_render_mesh_filepath(obj_template_handle)
    # TODO: this seems dirty, maybe refactor:
    # flatten the triangle soup into an unshared vertex list: triangle i uses vertices 3i, 3i+1, 3i+2
    tris = trimesh.load(input_filepath).triangles
    verts = [vert for tri in tris for vert in tri]
    indices = [
        [first, first + 1, first + 2] for first in range(0, 3 * len(tris), 3)
    ]
    imesh = coacd.Mesh()
    imesh.vertices = verts
    imesh.indices = indices
    parts = coacd.run_coacd(
        imesh,
        threshold=params.threshold,
        max_convex_hull=params.max_convex_hull,
        preprocess=params.preprocess,
        preprocess_resolution=params.preprocess_resolution,
        mcts_nodes=params.mcts_nodes,
        mcts_iterations=params.mcts_iterations,
        mcts_max_depth=params.mcts_max_depth,
        pca=params.pca,
        merge=params.merge,
        seed=params.seed,
    )
    mesh_parts = [
        trimesh.Trimesh(np.array(p.vertices), np.array(p.indices).reshape((-1, 3)))
        for p in parts
    ]
    scene = trimesh.Scene()

    # fixed seed so the per-part colors are reproducible (this reseeds numpy's global RNG)
    np.random.seed(0)
    for p in mesh_parts:
        p.visual.vertex_colors[:, :3] = (np.random.rand(3) * 255).astype(np.uint8)
        scene.add_geometry(p)
    scene.export(output_file)
    return output_file, len(parts)
def optimize_object_col_shape(
    self,
    obj_h: str,
    col_shape_dir: Optional[str] = None,
    method="coacd",
    param_range_override: Optional[Dict[str, List[Any]]] = None,
):
    """
    Run COACD optimization for a specific object.
    Identify the optimal collision shape and save the result as the new default.

    :param obj_h: Handle of the object template to optimize.
    :param col_shape_dir: Currently unused.
    :param method: The optimization method. The grid search and the asset copy only run for "coacd".
    :param param_range_override: Optional parameter ranges forwarded to the COACD grid search.

    :return: Tuple(best_shape_id, best_shape_score, original_shape_score, best_shape_params). If best_shape_id == "pr0", then optimization didn't change anything and best_shape_params is None.
    """
    import shutil

    otm = self.mm.object_template_manager
    obj_temp = otm.get_template_by_handle(obj_h)
    # resolved before the search runs, this is the file which gets overwritten by the best result
    cur_col_shape_path = os.path.abspath(obj_temp.collision_asset_handle)
    self.setup_obj_gt(obj_h)
    self.compute_proxy_metrics(obj_h)
    self.compute_receptacle_stability(obj_h, use_gt=True)
    self.compute_receptacle_stability(obj_h)
    self.compute_receptacle_access_metrics(obj_h, use_gt=True)
    self.compute_receptacle_access_metrics(obj_h, use_gt=False)
    if method == "coacd":
        self.run_coacd_grid_search(obj_h, param_range_override)
    self.compute_gt_errors(obj_h)

    # time to select the best version
    best_shape_id = "pr0"
    pr0_shape_score = self.compute_shape_score(obj_h, "pr0")
    settings_key = method + "_settings"
    best_shape_score = pr0_shape_score
    shape_scores = {}
    for shape_id in self.gt_data[obj_h][settings_key]:
        shape_score = self.compute_shape_score(obj_h, shape_id)
        shape_scores[shape_id] = shape_score
        # we only want significantly better shapes: the score must beat the current best
        # by more than 10% of its magnitude AND by more than 0.1 in absolute terms
        if (
            shape_score > (best_shape_score + abs(best_shape_score) * 0.1)
            and shape_score - best_shape_score > 0.1
        ):
            best_shape_id = shape_id
            best_shape_score = shape_score

    print(self.gt_data[obj_h][settings_key])
    print(shape_scores)

    if best_shape_id != "pr0":
        # re-save the best version
        print(
            f"Best shape_id = {best_shape_id} with shape score {best_shape_score} better than 'pr0' with shape score {pr0_shape_score}."
        )
        # copy the collision asset into the dataset directory
        if method == "coacd":
            asset_file = self.gt_data[obj_h]["coacd_output_files"][best_shape_id]
            # shutil instead of a shell `cp`: safe for paths with spaces and raises on failure
            shutil.copyfile(asset_file, cur_col_shape_path)
    else:
        print(
            f"Best shape_id = {best_shape_id} with shape score {best_shape_score}."
        )

    best_shape_params = None
    if best_shape_id != "pr0":
        best_shape_params = self.gt_data[obj_h][settings_key][best_shape_id]

    # the per-object ground truth cache is cleaned before returning
    self.clean_obj_gt(obj_h)
    return (best_shape_id, best_shape_score, pr0_shape_score, best_shape_params)
"total_time": col_subdiv_results["total_col_time"], + "avg_time": col_subdiv_results["avg_col_time"], + "max_time": col_subdiv_results["max_col_time"], + } + # populate the receptacle metric sub-cache + for rec_name, rec_data in self.gt_data[obj_handle]["receptacles"].items(): + self.results[obj_handle]["receptacle_metrics"][rec_name] = {} + for shape_id, shape_data in rec_data["shape_id_results"].items(): + self.results[obj_handle]["receptacle_metrics"][rec_name][ + shape_id + ] = {} + rsm = self.results[obj_handle]["receptacle_metrics"][rec_name][ + shape_id + ] + if "stability_results" in shape_data: + rsm["stability_success_ratio"] = shape_data[ + "stability_results" + ]["success_ratio"] + rsm["failed_snap"] = shape_data["stability_results"][ + "failed_snap" + ] + rsm["failed_by_distance"] = shape_data["stability_results"][ + "failed_by_distance" + ] + rsm["failed_unstable"] = shape_data["stability_results"][ + "failed_unstable" + ] + rsm["total"] = shape_data["stability_results"]["total"] + if "access_results" in shape_data: + rsm["receptacle_access_score"] = shape_data["access_results"][ + "receptacle_access_score" + ] + rsm["receptacle_access_rate"] = shape_data["access_results"][ + "receptacle_access_rate" + ] + + def save_results_to_csv(self, filename: str) -> None: + """ + Save current global results to a csv file in the self.output_directory. + """ + + assert len(self.results) > 0, "There must be results to save." + + assert ( + self.output_directory is not None + ), "Must have an output directory to save." 
+ + import csv + + filepath = os.path.join(self.output_directory, filename) + + # first collect all active metrics to log + active_subdivs = [] + active_shape_metrics = [] + for _obj_handle, obj_results in self.results.items(): + for _shape_id, shape_results in obj_results["shape_metrics"].items(): + for metric in shape_results: + if metric == "col_grid": + for subdiv in shape_results["col_grid"]: + if subdiv not in active_subdivs: + active_subdivs.append(subdiv) + else: + if metric not in active_shape_metrics: + active_shape_metrics.append(metric) + active_subdivs = sorted(active_subdivs) + + # save shape metric csv + with open(filepath + ".csv", "w") as f: + writer = csv.writer(f, quoting=csv.QUOTE_ALL) + # first collect all column names (metrics): + existing_cols = ["object_handle|shape_id"] + existing_cols.extend(active_shape_metrics) + for subdiv in active_subdivs: + existing_cols.append(f"col_grid_{subdiv}_total_time") + existing_cols.append(f"col_grid_{subdiv}_avg_time") + existing_cols.append(f"col_grid_{subdiv}_max_time") + # write column names row + writer.writerow(existing_cols) + + # write results rows + for obj_handle, obj_results in self.results.items(): + for shape_id, shape_results in obj_results["shape_metrics"].items(): + row_data = [obj_handle + "|" + shape_id] + for metric_key in active_shape_metrics: + if metric_key in shape_results: + row_data.append(shape_results[metric_key]) + else: + row_data.append("") + for subdiv in active_subdivs: + if subdiv in shape_results["col_grid"]: + row_data.append( + shape_results["col_grid"][subdiv]["total_time"] + ) + row_data.append( + shape_results["col_grid"][subdiv]["avg_time"] + ) + row_data.append( + shape_results["col_grid"][subdiv]["max_time"] + ) + else: + for _ in range(3): + row_data.append("") + writer.writerow(row_data) + + # collect active receptacle metrics + active_rec_metrics = [] + for _obj_handle, obj_results in self.results.items(): + for _rec_name, rec_results in 
def compute_and_save_results_for_objects(
    self, obj_handle_substrings: List[str], output_filename: str = "cpo_out"
) -> None:
    """
    Compute the metrics for a set of objects, cache them and export them to csv.

    :param obj_handle_substrings: Substrings which must each match exactly one object template handle.
    :param output_filename: Filename handed to save_results_to_csv once all objects are processed.
    """
    template_manager = self.mm.object_template_manager

    # resolve every substring to its single full template handle
    obj_handles = []
    for handle_substring in obj_handle_substrings:
        candidates = template_manager.get_file_template_handles(handle_substring)
        assert (
            len(candidates) == 1
        ), f"None or many matching handles to substring `{handle_substring}`: {candidates}"
        obj_handles.append(candidates[0])

    print(f"Found handles: {obj_handles}.")
    print("Computing metrics:")

    separator = "-------------------------------"
    num_objects = len(obj_handles)
    # compute, cache and clean up the metrics one object at a time
    for obix, obj_h in enumerate(obj_handles):
        print(separator)
        print(f"Computing metric for `{obj_h}`, {obix}|{num_objects}")
        print(separator)
        self.setup_obj_gt(obj_h)
        # NOTE: baseline metrics (compute_baseline_metrics) are currently disabled
        self.compute_proxy_metrics(obj_h)

        # NOTE: the physics tests are currently disabled:
        #  run_physics_settle_test, run_physics_sphere_shake_test and
        #  compute_grid_collision_times for subdivisions 0, 1 and 2

        # receptacle metrics are optional
        if self.compute_receptacle_useability_metrics:
            self.compute_receptacle_stability(obj_h, use_gt=True)
            self.compute_receptacle_stability(obj_h)
            print(" GT Receptacle Metrics:")
            self.compute_receptacle_access_metrics(obj_h, use_gt=True)
            print(" PR Receptacle Metrics:")
            self.compute_receptacle_access_metrics(obj_h, use_gt=False)

        self.compute_gt_errors(obj_h)
        print_dict_structure(self.gt_data)
        # the global cache must be filled before the per-object gt data is cleaned
        self.cache_global_results()
        print_dict_structure(self.results)
        self.clean_obj_gt(obj_h)

    # finally export everything which was cached
    self.save_results_to_csv(output_filename)
def parse_object_orientations_from_metadata_csv(
    metadata_csv: str,
) -> Dict[str, Tuple[mn.Vector3, mn.Vector3]]:
    """
    Parse the 'up' and 'front' vectors of objects from a csv metadata file.

    The first row is the header and must contain the labels "id", "up" and "front".
    Rows where both vectors are empty are skipped; rows where only one of them is set fail an assertion.

    :param metadata_csv: The absolute filepath of the metadata CSV.

    :return: A Dict mapping object ids to a Tuple of up, front vectors.
    """

    def str_to_vec(vec_str: str) -> mn.Vector3:
        """
        Convert a string of 3 comma separated numbers into a Vector3.
        """
        components = tuple(float(token) for token in vec_str.split(","))
        assert len(components) == 3, f"string '{vec_str}' must be a 3 vec."
        return mn.Vector3(components)

    orientations = {}

    with open(metadata_csv, newline="") as csvfile:
        rows = csv.reader(csvfile, delimiter=",")
        header = next(rows, None)
        # an empty file has no header and therefore no orientations
        if header is not None:
            id_col = header.index("id")
            up_col = header.index("up")
            front_col = header.index("front")
            for data_row in rows:
                up_str = data_row[up_col]
                front_str = data_row[front_col]
                if up_str and front_str:
                    orientations[data_row[id_col]] = (
                        str_to_vec(up_str),
                        str_to_vec(front_str),
                    )
                else:
                    # both must be set or neither
                    assert not up_str
                    assert not front_str

    return orientations
def _read_stripped_lines(filepath: str) -> List[str]:
    """Return the whitespace-stripped, non-empty lines of a text file."""
    assert os.path.exists(filepath), f"File '{filepath}' does not exist."
    with open(filepath, "r") as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]


def _is_excluded(obj_handle: str, excluded_object_strings: List[str]) -> bool:
    """Return whether any of the exclusion substrings occurs in the object handle."""
    return any(ex_obj in obj_handle for ex_obj in excluded_object_strings)


def main():
    """
    Command line entry point for automated collision shape creation and validation.

    Exactly one of --scenes, --objects, --all-rec-objects or --objects-file selects what is optimized:

    - objects: each selected object is optimized with the "coacd" method; an exception raised for one
      object is recorded and the batch continues. The handles of failed objects are written to file at the end.
    - scenes: all objects with receptacles instanced in each scene are optimized, or, with
      --export-fp-model-ids, their model ids are exported to a txt file and nothing is optimized.
    """
    parser = argparse.ArgumentParser(
        description="Automate collision shape creation and validation."
    )
    parser.add_argument("--dataset", type=str, help="path to SceneDataset.")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--scenes", type=str, nargs="+", help="one or more scenes to optimize."
    )
    group.add_argument(
        "--objects", type=str, nargs="+", help="one or more objects to optimize."
    )
    group.add_argument(
        "--all-rec-objects",
        action="store_true",
        help="Optimize all objects in the dataset with receptacles.",
    )
    group.add_argument(
        "--objects-file",
        type=str,
        help="optimize objects from a file containing object names separated by newline characters.",
    )
    parser.add_argument(
        "--start-ix",
        default=-1,
        type=int,
        help="If optimizing all assets, provide a start index.",
    )
    parser.add_argument(
        "--end-ix",
        default=-1,
        type=int,
        help="If optimizing all assets, provide an end index.",
    )
    parser.add_argument(
        "--parts-only",
        action="store_true",
        help="culls all objects without _part_ in the name.",
    )
    parser.add_argument(
        "--exclude",
        type=str,
        nargs="+",
        help="one or more objects to exclude from optimization (e.g. if it inspires a crash in COACD).",
    )
    parser.add_argument(
        "--exclude-files",
        type=str,
        nargs="+",
        help="provide one or more files with objects to exclude from optimization (NOTE: txt file with one id on each line, object names may include prefix 'fpModel.' which will be stripped.).",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="collision_shape_automation/",
        help="output directory for saved csv and images. Default = `./collision_shape_automation/`.",
    )
    parser.add_argument(
        "--debug-images",
        action="store_true",
        help="turns on debug image output.",
    )
    parser.add_argument(
        "--export-fp-model-ids",
        type=str,
        help="Intercept optimization to output a txt file with model ids for online model categorizer view.",
    )
    parser.add_argument(
        "--coacd-thresholds",
        type=float,
        nargs="+",
        help="one or more coacd thresholds [0-1] (lower is more detailed) to search. If not provided, default are [0.04, 0.01].",
    )
    args = parser.parse_args()

    if not args.all_rec_objects:
        assert (
            args.start_ix == -1
        ), "Can only provide start index for all objects optimization."
        assert (
            args.end_ix == -1
        ), "Can only provide end index for all objects optimization."

    param_range_overrides = None
    if args.coacd_thresholds:
        param_range_overrides = {
            "threshold": args.coacd_thresholds,
        }

    sim_settings = default_sim_settings.copy()
    sim_settings["scene_dataset_config_file"] = args.dataset
    # necessary for debug rendering
    sim_settings["sensor_height"] = 0
    sim_settings["width"] = 720
    sim_settings["height"] = 720
    sim_settings["clear_color"] = mn.Color4.magenta() * 0.5
    sim_settings["default_agent_navmesh"] = False

    # use the CollisionProxyOptimizer to compute metrics for multiple objects
    cpo = CollisionProxyOptimizer(sim_settings, output_directory=args.output_dir)
    cpo.generate_debug_images = args.debug_images
    otm = cpo.mm.object_template_manager

    # collect the exclusion substrings from the command line and from files
    excluded_object_strings = []
    if args.exclude:
        excluded_object_strings = args.exclude
    if args.exclude_files:
        for filepath in args.exclude_files:
            stripped_ids = (
                line.split("fpModel.")[-1] for line in _read_stripped_lines(filepath)
            )
            # an empty exclusion string is a substring of every handle and would exclude everything
            excluded_object_strings.extend(
                obj_id for obj_id in stripped_ids if obj_id
            )
    excluded_object_strings = list(dict.fromkeys(excluded_object_strings))

    # ----------------------------------------------------
    # specific object handle provided
    if args.objects or args.all_rec_objects or args.objects_file:
        assert (
            not args.export_fp_model_ids
        ), "Feature not available for objects, only for scenes."

        unique_objects = None

        if args.objects:
            # deduplicate the list
            unique_objects = list(dict.fromkeys(args.objects))
        elif args.objects_file:
            # blank lines are dropped: an empty name would match every template in the dataset
            unique_objects = list(
                dict.fromkeys(_read_stripped_lines(args.objects_file))
            )
        elif args.all_rec_objects:
            rec_obj_in_dataset = [
                handle
                for handle in otm.get_file_template_handles()
                if object_has_receptacles(handle, otm)
            ]
            print(
                f"Number of objects in dataset with receptacles = {len(rec_obj_in_dataset)}"
            )
            unique_objects = rec_obj_in_dataset

        # validate the object handles
        object_handles = []
        for object_name in unique_objects:
            # get templates matches
            matching_templates = otm.get_file_template_handles(object_name)
            assert (
                len(matching_templates) > 0
            ), f"No matching templates in the dataset for '{object_name}'"
            assert (
                len(matching_templates) == 1
            ), f"More than one matching template in the dataset for '{object_name}': {matching_templates}"
            obj_h = matching_templates[0]

            # skip excluded objects
            if _is_excluded(obj_h, excluded_object_strings):
                print(f"Excluding object {object_name}.")
            else:
                object_handles.append(obj_h)

        if args.parts_only:
            object_handles = [obj_h for obj_h in object_handles if "_part_" in obj_h]
            print(f"part objects only = {object_handles}")

        # optimize the objects
        # results are stored together with their handle: with a start index or after a failure
        # the position in `results` no longer matches the position in `object_handles`
        results: List[Tuple[str, Any]] = []
        failures = []
        start = args.start_ix if args.start_ix >= 0 else 0
        # clamp the end index so an index past the end of the list cannot raise an IndexError
        end = (
            min(args.end_ix, len(object_handles))
            if args.end_ix >= 0
            else len(object_handles)
        )
        assert (
            end >= start
        ), f"Start index ({start}) is greater than end index ({end})."
        for obj_ix in range(start, end):
            obj_h = object_handles[obj_ix]
            print("+++++++++++++++++++++++++")
            print("+++++++++++++++++++++++++")
            print(f"Optimizing '{obj_h}' : {obj_ix} of {len(object_handles)}")
            print("+++++++++++++++++++++++++")
            try:
                results.append(
                    (
                        obj_h,
                        cpo.optimize_object_col_shape(
                            obj_h,
                            method="coacd",
                            param_range_override=param_range_overrides,
                        ),
                    )
                )
                print(
                    f"Completed optimization of '{obj_h}' : {obj_ix} of {len(object_handles)}"
                )
            except Exception as err:
                # deliberate best-effort: record the failure and continue with the batch
                failures.append((obj_ix, obj_h, err))
        # display results
        print("Object Optimization Results:")
        for obj_h, obj_result in results:
            print(f" {obj_h}: {obj_result}")
        print("Failures:")
        for f in failures:
            print(f" {f}")
        write_failure_ids(failures)
    # ----------------------------------------------------

    # ----------------------------------------------------
    # run the pipeline for a set of object parsed from a scene
    if args.scenes:
        scene_object_handles: Dict[str, List[str]] = {}

        # deduplicate the list
        unique_scenes = list(dict.fromkeys(args.scenes))

        # first validate the scene names have a unique match
        scene_handles = cpo.mm.get_scene_handles()
        for scene_name in unique_scenes:
            matching_scenes = [h for h in scene_handles if scene_name in h]
            assert (
                len(matching_scenes) > 0
            ), f"No scenes found matching provided scene name '{scene_name}'."
            assert (
                len(matching_scenes) == 1
            ), f"More than one scenes found matching provided scene name '{scene_name}': {matching_scenes}."

        # collect all the objects for all the scenes in advance
        for scene_name in unique_scenes:
            objects_in_scene = get_objects_in_scene(
                dataset_path=args.dataset, scene_handle=scene_name, mm=cpo.mm
            )
            assert (
                len(objects_in_scene) > 0
            ), f"No objects found in scene '{scene_name}'. Are you sure this is a valid scene?"

            # skip excluded objects
            included_objects = []
            for obj_h in objects_in_scene:
                if _is_excluded(obj_h, excluded_object_strings):
                    print(f"Excluding object {obj_h}.")
                else:
                    included_objects.append(obj_h)
            scene_object_handles[scene_name] = included_objects

        if args.export_fp_model_ids:
            # intercept optimization to instead export a txt file with model ids for import into the model categorizer tool
            aggregated_object_ids = []
            for scene_objects in scene_object_handles.values():
                aggregated_object_ids.extend(
                    obj_h
                    for obj_h in scene_objects
                    if object_has_receptacles(obj_h, otm)
                )
            aggregated_object_ids = list(dict.fromkeys(aggregated_object_ids))
            with open(args.export_fp_model_ids, "w") as f:
                for obj_h in aggregated_object_ids:
                    obj_name = obj_h.split(".object_config.json")[0].split("/")[-1]
                    # TODO: this will change once the Model Categorizer supports these
                    if "_part_" not in obj_name:
                        f.write("fpModel." + obj_name + "\n")
            print(f"Export fpModel ids to {args.export_fp_model_ids}")
            # nothing is optimized in export mode
            return

        # optimize each scene
        all_scene_results: Dict[str, Dict[str, Tuple[str, float, float, Any]]] = {}
        for scene, objects_in_scene in scene_object_handles.items():
            # clear and re-initialize the caches between scenes to prevent memory overflow on large batches.
            cpo.init_caches()

            # ----------------------------------------------------
            # get a subset of objects with receptacles defined
            rec_obj_in_scene = [
                obj_h
                for obj_h in objects_in_scene
                if object_has_receptacles(obj_h, otm)
            ]
            print(
                f"Number of objects in scene '{scene}' with receptacles = {len(rec_obj_in_scene)}"
            )
            # ----------------------------------------------------

            # BUG: Receptacles are not re-oriented by internal re-orientation transforms.
            # Until that is fixed the object orientation metadata
            # (parse_object_orientations_from_metadata_csv / correct_object_orientations)
            # is not applied before optimization.

            # run shape opt for all objects in the scene
            scene_results: Dict[str, Tuple[str, float, float, Any]] = {}
            for obj_h in rec_obj_in_scene:
                scene_results[obj_h] = cpo.optimize_object_col_shape(
                    obj_h, method="coacd", param_range_override=param_range_overrides
                )

            all_scene_results[scene] = scene_results

            print("------------------------------------")
            print(f"Finished optimization of scene '{scene}': \n {scene_results}")
            print("------------------------------------")

        print("==========================================")
        print(f"Finished optimization of all scenes: \n {all_scene_results}")
        print("==========================================")
def load_model_list_from_csv(
    filepath: str, header_label: str = "Model ID"
) -> List[str]:
    """
    Scrape a csv file for a list of model ids under a header label.

    :param filepath: Path to an existing file ending with ".csv". The first row is the header.
    :param header_label: The header label of the column which contains the model ids.

    :return: One entry per data row in file order. Empty or missing cells produce "" to keep consistency with row ordering in the sheet (for copy/paste).

    :raises ValueError: If the header row does not contain 'header_label'.
    """
    assert filepath.endswith(".csv"), "This isn't a .csv file."
    assert os.path.exists(filepath)

    with open(filepath, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            # an empty file has no header and therefore no ids
            return []
        id_column = header.index(header_label)
        # csv.reader yields [] for a blank line and a short list for a ragged row:
        # treat the missing cell as empty instead of failing with an IndexError
        return [row[id_column] if id_column < len(row) else "" for row in reader]
Heuristic is to expect at least one. + num_rec_meshes = len( + find_files(folder_path, file_endswith, "_receptacle_mesh.glb") + ) + one_receptacle_exists = num_rec_meshes > 0 + + one_render_mesh_exists = ( + len(find_files(folder_path, file_endswith, ".glb")) - num_rec_meshes + ) > 0 + + parse_results_report[model_id] = [ + model_id, + folder_exists, + urdf_exists, + config_exists, + one_receptacle_exists, + one_render_mesh_exists, + ] + global_count["folder"] += int(not folder_exists) + global_count["config"] += int(not config_exists) + global_count["receptacles"] += int(not one_receptacle_exists) + global_count["urdf"] += int(not urdf_exists) + global_count["render_meshes"] += int(not one_render_mesh_exists) + else: + parse_results_report[model_id] = [False for i in range(len(cat_columns))] + parse_results_report[model_id][0] = model_id + for key in global_count: + global_count[key] += 1 + + # export results to a file + os.makedirs(os.path.dirname(report_filepath), exist_ok=True) + with open(report_filepath, "w", newline="") as f: + writer = csv.writer(f, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL) + # write the header labels + writer.writerow(cat_columns) + # write the contents + for model_id in model_ids: + writer.writerow(parse_results_report[model_id]) + + print("-----------------------------------------------") + print(f"Wrote report to {report_filepath}.\n") + + print("The following folders were unclaimed. 
def find_files(root_dir: str, discriminator: Callable[[str], bool]) -> List[str]:
    """
    Recursively find all filepaths under a root directory satisfying a particular constraint as defined by a discriminator function.

    :param root_dir: The root directory for the recursive search.
    :param discriminator: The discriminator function which takes a filepath and returns a bool.

    :return: The list of all filepaths found satisfying the discriminator. Each path is 'root_dir' joined with the relative location of the file. Empty if 'root_dir' does not exist.
    """
    filepaths: List[str] = []

    if not os.path.exists(root_dir):
        # report the directory which was requested (not the builtin `dir` function)
        print(" Directory does not exist: " + str(root_dir))
        return filepaths

    for entry in os.listdir(root_dir):
        entry_path = os.path.join(root_dir, entry)
        if os.path.isdir(entry_path):
            # descend into sub-directories
            filepaths.extend(find_files(entry_path, discriminator))
        # apply a user-provided discriminator function to cull filepaths
        elif discriminator(entry_path):
            filepaths.append(entry_path)
    return filepaths
+ ) + parser.add_argument( + "--dataset-root-dir", + type=str, + help="path to HSSD SceneDataset root directory containing 'fphab-uncluttered.scene_dataset_config.json'.", + ) + args = parser.parse_args() + fp_root_dir = args.dataset_root_dir + config_root_dir = os.path.join(fp_root_dir, "scenes-uncluttered") + configs = find_files(config_root_dir, file_is_scene_config) + + for _ix, filepath in enumerate(configs): + remove_ssd_from_scene_instance_json(filepath) + + +if __name__ == "__main__": + main() diff --git a/tools/replace_articulated_models_in_rigid_scene.py b/tools/replace_articulated_models_in_rigid_scene.py new file mode 100644 index 0000000000..e8e87622d4 --- /dev/null +++ b/tools/replace_articulated_models_in_rigid_scene.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import json +import os +from typing import Callable, List + + +def file_is_scene_config(filepath: str) -> bool: + """ + Return whether or not the file is an scene_instance.json + """ + return filepath.endswith(".scene_instance.json") + + +def file_is_urdf(filepath: str) -> bool: + """ + Return whether or not the file is a .urdf + """ + return filepath.endswith(".urdf") + + +def find_files(root_dir: str, discriminator: Callable[[str], bool]) -> List[str]: + """ + Recursively find all filepaths under a root directory satisfying a particular constraint as defined by a discriminator function. + + :param root_dir: The roor directory for the recursive search. + :param discriminator: The discriminator function which takes a filepath and returns a bool. + + :return: The list of all absolute filepaths found satisfying the discriminator. 
+ """ + filepaths: List[str] = [] + + if not os.path.exists(root_dir): + print(" Directory does not exist: " + str(dir)) + return filepaths + + for entry in os.listdir(root_dir): + entry_path = os.path.join(root_dir, entry) + if os.path.isdir(entry_path): + sub_dir_filepaths = find_files(entry_path, discriminator) + filepaths.extend(sub_dir_filepaths) + # apply a user-provided discriminator function to cull filepaths + elif discriminator(entry_path): + filepaths.append(entry_path) + return filepaths + + +scenes_without_filters = [] + + +def find_and_replace_articulated_models_for_config( + filepath: str, + top_dir: str, + urdf_names: str, + src_dir: str = "scenes-uncluttered", + dest_dir: str = "scenes-articulated-uncluttered", +) -> None: + """ + For a given scene config, try to find a matching articulated objects for each rigid object. If found, add them to the config, replacing the rigid objects. + + :param top_dir: The top directory of the dataset from which the rec filter file path should be relative. + """ + assert filepath.endswith(".scene_instance.json"), "Must be a scene instance JSON." 
def find_and_replace_articulated_models_for_config(
    filepath: str,
    top_dir: str,
    urdf_names: List[str],
    src_dir: str = "scenes-uncluttered",
    dest_dir: str = "scenes-articulated-uncluttered",
) -> None:
    """
    For a given scene config, try to find a matching articulated object for each rigid object. If found, add them to the config, replacing the rigid objects.

    A rigid object matches the first entry of urdf_names which contains the object's "template_name" as a substring. Matched objects are removed from "object_instances" and appended to "articulated_object_instances", carrying over "translation" and "rotation" when present. The modified config is written only if at least one object was replaced; the source file is never modified.

    :param filepath: Path to the source ".scene_instance.json" file. Must contain src_dir.
    :param top_dir: The top directory of the dataset from which the rec filter file path should be relative. Currently unused by this function.
    :param urdf_names: Names of the candidate articulated object models (urdf file names without extension).
    :param src_dir: Name of the source scene config directory appearing in filepath.
    :param dest_dir: Name of the directory substituted for the last occurrence of src_dir in filepath to form the output path. Missing output directories are created.

    :raises ValueError: If a modified config must be written but src_dir does not appear in filepath.
    """
    assert filepath.endswith(".scene_instance.json"), "Must be a scene instance JSON."
    scene_name = os.path.basename(filepath)[: -len(".scene_instance.json")]

    print(f"scene_name = {scene_name}")

    with open(filepath, "r", encoding="utf-8") as f:
        scene_conf = json.load(f)

    # pre-existing articulated objects are kept and extended
    ao_instance_data = scene_conf.get("articulated_object_instances", [])

    file_is_modified = False
    modified_object_instance_data = []
    # a config without rigid objects simply has nothing to replace
    for object_instance_data in scene_conf.get("object_instances", []):
        object_name = object_instance_data["template_name"]

        # look for a matching articulated object entry (first match wins)
        urdf_name_match = next(
            (urdf_name for urdf_name in urdf_names if object_name in urdf_name),
            None,
        )

        if urdf_name_match is None:
            # no articulated counterpart: keep the rigid object as-is
            modified_object_instance_data.append(object_instance_data)
            continue

        file_is_modified = True
        # TODO: all objects are non-uniformly scaled and won't fit exactly in the scenes...
        # NOTE: any "non_uniform_scale" on the rigid object is dropped here.
        this_ao_instance_data = {
            "template_name": urdf_name_match,
            "translation_origin": "COM",
            "fixed_base": True,
            "motion_type": "DYNAMIC",
        }
        for key in ("translation", "rotation"):
            if key in object_instance_data:
                this_ao_instance_data[key] = object_instance_data[key]
        ao_instance_data.append(this_ao_instance_data)

    if not file_is_modified:
        return

    scene_conf["object_instances"] = modified_object_instance_data
    scene_conf["articulated_object_instances"] = ao_instance_data

    # swap only the LAST occurrence of src_dir so that parent directories
    # which happen to contain the same text are left untouched
    head, found, tail = filepath.rpartition(src_dir)
    if not found:
        raise ValueError(
            f"Cannot derive output path: '{src_dir}' does not appear in '{filepath}'."
        )
    out_filepath = head + dest_dir + tail

    # scene configs may live in sub-directories of src_dir; mirror them
    out_dir = os.path.dirname(out_filepath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_filepath, "w", encoding="utf-8") as f:
        json.dump(scene_conf, f, indent=4)
def main():
    """
    Command line entry point: replace rigid objects with articulated counterparts in the scene configs of a dataset.

    Collects all ".scene_instance.json" files under <dataset-root-dir>/<src-dir> and all ".urdf" files under <dataset-root-dir>/urdf/. Only urdfs accompanied by an ".ao_config.json" and at least one non-receptacle ".glb" in the same directory are considered. Each scene config (optionally limited by --scenes) is then processed and, if modified, written under <dataset-root-dir>/<dest-dir>. A summary is printed at the end.
    """
    parser = argparse.ArgumentParser(
        description="Modify the scene_instance.json files, replacing rigid objects with articulated coutnerparts in a urdf/ directory."
    )
    parser.add_argument(
        "--dataset-root-dir",
        type=str,
        # every path below is derived from this; fail with a usage error
        # instead of a TypeError from os.path.join(None, ...)
        required=True,
        help="path to HSSD SceneDataset root directory containing 'fphab-uncluttered.scene_dataset_config.json'.",
    )
    parser.add_argument(
        "--src-dir",
        type=str,
        default="scenes-uncluttered",
        help="Name of the source scene config directory within root-dir.",
    )
    parser.add_argument(
        "--dest-dir",
        type=str,
        default="scenes-articulated-uncluttered",
        help="Name of the destination scene config directory within root-dir. Will be created if doesn't exist.",
    )
    parser.add_argument(
        "--scenes",
        nargs="+",
        type=str,
        help="Substrings which indicate scenes which should be converted. When provided, only these scenes are converted.",
        default=None,
    )
    args = parser.parse_args()
    fp_root_dir = args.dataset_root_dir
    src_dir = args.src_dir
    dest_dir = args.dest_dir
    config_root_dir = os.path.join(fp_root_dir, src_dir)
    configs = find_files(config_root_dir, file_is_scene_config)
    urdf_dir = os.path.join(fp_root_dir, "urdf/")
    urdf_files = find_files(urdf_dir, file_is_urdf)

    # create scene output directory
    os.makedirs(os.path.join(fp_root_dir, dest_dir), exist_ok=True)

    # only consider urdf files with reasonable accompanying contents
    def urdf_has_meshes_and_config(urdf_filepath: str) -> bool:
        """
        Return whether or not there are render meshes and a config accompanying the urdf.
        """
        if not os.path.exists(urdf_filepath.split(".urdf")[0] + ".ao_config.json"):
            return False
        # any .glb next to the urdf which is not a receptacle mesh counts
        return any(
            file_name.endswith(".glb") and "receptacle" not in file_name
            for file_name in os.listdir(os.path.dirname(urdf_filepath))
        )

    # partition in a single pass, validating each urdf exactly once
    valid_urdf_files: List[str] = []
    invalid_urdf_files: List[str] = []
    for urdf_file in urdf_files:
        if urdf_has_meshes_and_config(urdf_file):
            valid_urdf_files.append(urdf_file)
        else:
            invalid_urdf_files.append(urdf_file)

    # model name == urdf file name without directory or extension
    urdf_names = [
        os.path.basename(urdf_filename).split(".urdf")[0]
        for urdf_filename in valid_urdf_files
    ]

    # limit the scenes which are converted; filtering the original list keeps
    # each config once and in a deterministic order
    if args.scenes is not None:
        configs = [
            config
            for config in configs
            if any(scene + "." in config for scene in args.scenes)
        ]

    for filepath in configs:
        find_and_replace_articulated_models_for_config(
            filepath,
            urdf_names=urdf_names,
            top_dir=fp_root_dir,
            src_dir=src_dir,
            dest_dir=dest_dir,
        )

    print(
        f"Migrated {len(valid_urdf_files)} urdfs into {len(configs)} scene configs. Invalid urdfs found and skipped ({len(invalid_urdf_files)}) = {invalid_urdf_files}"
    )