diff --git a/docs/source/scripts/generate_tfrecord.py b/docs/source/scripts/generate_tfrecord.py index caad456..944b9ec 100644 --- a/docs/source/scripts/generate_tfrecord.py +++ b/docs/source/scripts/generate_tfrecord.py @@ -80,15 +80,19 @@ def xml_to_csv(path): for xml_file in glob.glob(path + '/*.xml'): tree = ET.parse(xml_file) root = tree.getroot() + filename = root.find('filename').text + width = int(root.find('size').find('width').text) + height = int(root.find('size').find('height').text) for member in root.findall('object'): - value = (root.find('filename').text, - int(root.find('size')[0].text), - int(root.find('size')[1].text), - member[0].text, - int(member[4][0].text), - int(member[4][1].text), - int(member[4][2].text), - int(member[4][3].text) + bndbox = member.find('bndbox') + value = (filename, + width, + height, + member.find('name').text, + int(bndbox.find('xmin').text), + int(bndbox.find('ymin').text), + int(bndbox.find('xmax').text), + int(bndbox.find('ymax').text), ) xml_list.append(value) column_name = ['filename', 'width', 'height',