diff --git a/python/rainbow/util/USD.py b/python/rainbow/util/USD.py index 031a395..ed9adfa 100644 --- a/python/rainbow/util/USD.py +++ b/python/rainbow/util/USD.py @@ -65,6 +65,30 @@ def set_mesh_positions(self, name: str, V: ArrayLike, time: float) -> None: raise ValueError(f'Mesh {name} does not exist') vertex_positions = Vt.Vec3fArray(V.tolist()) self.meshes[name].GetPointsAttr().Set(vertex_positions, time) + + def get_mesh_positions(self, name: str, time: float) -> ArrayLike: + """ Retrieve the positions of a mesh at a given timestamp. + + Args: + name (str): The name of the mesh. + time (float): The timestamp at which the positions should be retrieved. + + Returns: + ArrayLike: An array containing the vertex positions of the mesh. + + Raises: + ValueError: If the mesh does not exist in the scene, or if the mesh does not have positions set at the given timestamp. + """ + if name not in self.meshes: + raise ValueError(f'Mesh {name} does not exist') + + vertex_positions_attr = self.meshes[name].GetPointsAttr() + vertex_positions = vertex_positions_attr.Get(time) + + if vertex_positions: + return np.array(vertex_positions, dtype=np.float64) + else: + raise ValueError(f"No positions set for mesh {name} at time {time}") def set_animation_time(self, duration: float) -> None: """ Set the total animation time of the scene diff --git a/python/unit_tests/test_utils_USD.py b/python/unit_tests/test_utils_USD.py index 795b4c9..d72d353 100644 --- a/python/unit_tests/test_utils_USD.py +++ b/python/unit_tests/test_utils_USD.py @@ -32,10 +32,12 @@ def test_set_mesh_positions(self): self.usd_instance.add_mesh( self.sample_mesh_name, self.sample_vertex_positions, self.sample_triangle_faces) new_positions = np.array( - [[0.1, 0.1, 0.1], [1.1, 0.1, 0.1], [0.1, 1.1, 0.1]]) + [[0.1, 0.1, 0.1], [1.1, 0.1, 0.1], [0.1, 1.1, 0.1]], dtype=np.float64) time_stamp = 1.0 self.usd_instance.set_mesh_positions( self.sample_mesh_name, new_positions, time_stamp) + updated_positions = self.usd_instance.get_mesh_positions(self.sample_mesh_name, time_stamp) + self.assertTrue(np.allclose(new_positions, updated_positions, rtol=1e-5, atol=1e-10)) def test_set_mesh_positions_with_invalid_name(self): """ Test if the mesh does not exist, an error will be raised.