Skip to content

Commit 5e5efba

Browse files
Add object detection repository
1 parent 45c5610 commit 5e5efba

File tree

5 files changed

+729
-0
lines changed

5 files changed

+729
-0
lines changed

object_detection/label_map_util.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Label map utility functions."""
17+
18+
import logging
19+
from typing import Dict, List
20+
21+
from google.protobuf import text_format # type: ignore
22+
import string_int_label_map_pb2
23+
24+
25+
def _validate_label_map(label_map):
26+
"""Checks if a label map is valid.
27+
28+
Args:
29+
label_map: StringIntLabelMap to validate.
30+
31+
Raises:
32+
ValueError: if label map is invalid.
33+
"""
34+
for item in label_map.item:
35+
if item.id < 0:
36+
raise ValueError("Label map ids should be >= 0.")
37+
if (
38+
item.id == 0
39+
and item.name != "background"
40+
and item.display_name != "background"
41+
):
42+
raise ValueError("Label map id 0 is reserved for the background label")
43+
44+
45+
CategoryIndex = Dict[int, Dict]
46+
47+
48+
def create_category_index(categories: List[Dict]) -> CategoryIndex:
49+
"""Creates dictionary of COCO compatible categories keyed by category id.
50+
51+
Args:
52+
categories: a list of dicts, each of which has the following keys:
53+
'id': (required) an integer id uniquely identifying this category.
54+
'name': (required) string representing category name
55+
e.g., 'cat', 'dog', 'pizza'.
56+
57+
Returns:
58+
category_index: a dict containing the same entries as categories, but keyed
59+
by the 'id' field of each category.
60+
"""
61+
category_index = {}
62+
for cat in categories:
63+
category_index[cat["id"]] = cat
64+
return category_index
65+
66+
67+
def get_max_label_map_index(label_map):
68+
"""Get maximum index in label map.
69+
70+
Args:
71+
label_map: a StringIntLabelMapProto
72+
73+
Returns:
74+
an integer
75+
"""
76+
return max([item.id for item in label_map.item])
77+
78+
79+
def convert_label_map_to_categories(
80+
label_map, max_num_classes, use_display_name=True
81+
) -> List[Dict]:
82+
"""Loads label map proto and returns categories list compatible with eval.
83+
84+
This function loads a label map and returns a list of dicts, each of which
85+
has the following keys:
86+
'id': (required) an integer id uniquely identifying this category.
87+
'name': (required) string representing category name
88+
e.g., 'cat', 'dog', 'pizza'.
89+
We only allow class into the list if its id-label_id_offset is
90+
between 0 (inclusive) and max_num_classes (exclusive).
91+
If there are several items mapping to the same id in the label map,
92+
we will only keep the first one in the categories list.
93+
94+
Args:
95+
label_map: a StringIntLabelMapProto or None. If None, a default categories
96+
list is created with max_num_classes categories.
97+
max_num_classes: maximum number of (consecutive) label indices to include.
98+
use_display_name: (boolean) choose whether to load 'display_name' field
99+
as category name. If False or if the display_name field does not exist,
100+
uses 'name' field as category names instead.
101+
Returns:
102+
categories: a list of dictionaries representing all possible categories.
103+
"""
104+
categories = []
105+
list_of_ids_already_added: List = []
106+
if not label_map:
107+
label_id_offset = 1
108+
for class_id in range(max_num_classes):
109+
categories.append(
110+
{
111+
"id": class_id + label_id_offset,
112+
"name": "category_{}".format(class_id + label_id_offset),
113+
}
114+
)
115+
return categories
116+
for item in label_map.item:
117+
if not 0 < item.id <= max_num_classes:
118+
logging.info(
119+
"Ignore item %d since it falls outside of requested " "label range.",
120+
item.id,
121+
)
122+
continue
123+
if use_display_name and item.HasField("display_name"):
124+
name = item.display_name
125+
else:
126+
name = item.name
127+
if item.id not in list_of_ids_already_added:
128+
list_of_ids_already_added.append(item.id)
129+
categories.append({"id": item.id, "name": name})
130+
return categories
131+
132+
133+
def load_labelmap(path: str):
134+
"""Loads label map proto.
135+
136+
Args:
137+
path: path to StringIntLabelMap proto text file.
138+
Returns:
139+
a StringIntLabelMapProto
140+
"""
141+
with open(path, "r") as fid:
142+
label_map_string = fid.read()
143+
label_map = string_int_label_map_pb2.StringIntLabelMap()
144+
try:
145+
text_format.Merge(label_map_string, label_map)
146+
except text_format.ParseError:
147+
label_map.ParseFromString(label_map_string)
148+
_validate_label_map(label_map)
149+
return label_map
150+
151+
152+
def get_label_map_dict(label_map_path, use_display_name=False):
153+
"""Reads a label map and returns a dictionary of label names to id.
154+
155+
Args:
156+
label_map_path: path to label_map.
157+
use_display_name: whether to use the label map items' display names as keys.
158+
159+
Returns:
160+
A dictionary mapping label names to id.
161+
"""
162+
label_map = load_labelmap(label_map_path)
163+
label_map_dict = {}
164+
for item in label_map.item:
165+
if use_display_name:
166+
label_map_dict[item.display_name] = item.id
167+
else:
168+
label_map_dict[item.name] = item.id
169+
return label_map_dict
170+
171+
172+
def create_category_index_from_labelmap(label_map_path):
173+
"""Reads a label map and returns a category index.
174+
175+
Args:
176+
label_map_path: Path to `StringIntLabelMap` proto text file.
177+
178+
Returns:
179+
A category index, which is a dictionary that maps integer ids to dicts
180+
containing categories, e.g.
181+
{1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
182+
"""
183+
label_map = load_labelmap(label_map_path)
184+
max_num_classes = max(item.id for item in label_map.item)
185+
categories = convert_label_map_to_categories(label_map, max_num_classes)
186+
return create_category_index(categories)
187+
188+
189+
def create_class_agnostic_category_index():
190+
"""Creates a category index with a single `object` class."""
191+
return {1: {"id": 1, "name": "object"}}

0 commit comments

Comments
 (0)