diff --git a/src/pycram/object_descriptors/mjcf.py b/src/pycram/object_descriptors/mjcf.py index d4200db1f..c5a9b95e4 100644 --- a/src/pycram/object_descriptors/mjcf.py +++ b/src/pycram/object_descriptors/mjcf.py @@ -1,3 +1,4 @@ +import os import pathlib import numpy as np @@ -5,6 +6,7 @@ from dm_control import mjcf from geometry_msgs.msg import Point from typing_extensions import Union, List, Optional, Dict, Tuple +from xml.etree import ElementTree as ET from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \ SphereVisualShape, MeshVisualShape @@ -238,6 +240,16 @@ class ObjectDescription(AbstractObjectDescription): A class that represents an object description of an object. """ + COMPILER_TAG = 'compiler' + """ + The tag of the compiler element in the MJCF file. + """ + MESH_DIR_ATTR = 'meshdir' + TEXTURE_DIR_ATTR = 'texturedir' + """ + The attributes of the compiler element in the MJCF file. The meshdir attribute is the directory where the mesh files + are stored and the texturedir attribute is the directory where the texture files are stored.""" + class Link(AbstractObjectDescription.Link, LinkDescription): ... @@ -367,8 +379,24 @@ def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = factory.export_to_mjcf(output_file_path=save_path) def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: - mjcf_model = mjcf.from_file(path) - self.write_description_to_file(mjcf_model, save_path) + model_str = self.replace_relative_paths_with_absolute_paths(path) + self.write_description_to_file(model_str, save_path) + + def replace_relative_paths_with_absolute_paths(self, model_path: str) -> str: + """ + Replace the relative paths in the xml file to be absolute paths. + + :param model_path: The path to the xml file. + """ + tree = ET.parse(model_path) + root = tree.getroot() + compiler = root.find(self.COMPILER_TAG) + model_dir = pathlib.Path(model_path).parent + for rel_dir_attrib in [self.MESH_DIR_ATTR, self.TEXTURE_DIR_ATTR]: + rel_dir = compiler.get(rel_dir_attrib) + abs_dir = str(pathlib.Path(os.path.join(model_dir, rel_dir)).resolve()) + compiler.set(rel_dir_attrib, abs_dir) + return ET.tostring(root, encoding='unicode', method='xml') def generate_from_parameter_server(self, name: str, save_path: str) -> None: mjcf_string = rospy.get_param(name) diff --git a/test/test_mjcf.py b/test/test_mjcf.py index bcb9e2262..edfbb842b 100644 --- a/test/test_mjcf.py +++ b/test/test_mjcf.py @@ -26,7 +26,6 @@ def setUpClass(cls): joint2 = body3.add('joint', name='joint2', type='slide') cls.model = MJCFObjDesc() - print(model.to_xml_string()) cls.model.update_description_from_string(model.to_xml_string()) def test_child_map(self): @@ -38,3 +37,4 @@ def test_parent_map(self): def test_get_chain(self): self.assertEqual(self.model.get_chain('body1', 'body3'), ['body1', 'joint1', 'body2', 'joint2', 'body3']) + diff --git a/test/test_multiverse.py b/test/test_multiverse.py index 4c1426211..859ff1c2a 100644 --- a/test/test_multiverse.py +++ b/test/test_multiverse.py @@ -53,6 +53,10 @@ def tearDownClass(cls): def tearDown(self): self.multiverse.remove_all_objects() + def test_spawn_xml_object(self): + bread = Object("bread_1", ObjectType.GENERIC_OBJECT, "bread_1.xml", pose=Pose([1, 1, 0.1])) + self.assert_poses_are_equal(bread.get_pose(), Pose([1, 1, 0.1])) + def test_spawn_mesh_object(self): milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1, 1, 0.1])) self.assert_poses_are_equal(milk.get_pose(), Pose([1, 1, 0.1]))