Skip to content

Commit 158d27a

Browse files
committed
add legend to plotDetection
1 parent f8213f5 commit 158d27a

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

MTM/__init__.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
"""
77
import numpy as np
88
import matplotlib.pyplot as plt
9+
import warnings
10+
from matplotlib.lines import Line2D
911
from skimage import feature
1012
from .NMS import NMS
1113
from .Detection import BoundingBox
@@ -111,7 +113,7 @@ def findMatches(image,
111113

112114
def matchTemplates(image,
113115
listTemplates,
114-
listLabels=[],
116+
listLabels=None,
115117
score_threshold=0.5,
116118
maxOverlap=0.25,
117119
nObjects=float("inf"),
@@ -151,7 +153,7 @@ def matchTemplates(image,
151153
return bestHits
152154

153155

154-
def plotDetections(image, listDetections, thickness=2):
156+
def plotDetections(image, listDetections, thickness=2, showLegend=False):
155157
"""
156158
Plot the detections overlaid on the image.
157159
@@ -171,9 +173,10 @@ def plotDetections(image, listDetections, thickness=2):
171173
- thickness (optional, default=2): int
172174
thickness of plotted contour in pixels
173175
174-
- showLabel: Boolean
175-
Display label of the bounding box (field TemplateName)
176-
Not implemented
176+
- showLegend (optional, default=False): Boolean
177+
Display a legend panel with the category labels for each color.
178+
This works if the Detections have a label
179+
(not just "", in which case the legend is not shown).
177180
"""
178181
plt.figure()
179182
plt.imshow(image, cmap="gray") # cmap gray only impacts gray images
@@ -184,9 +187,41 @@ def plotDetections(image, listDetections, thickness=2):
184187
palette = plt.cm.Set3.colors
185188
nColors = len(palette)
186189

190+
if showLegend:
191+
mapLabelColor = {}
192+
187193
for detection in listDetections:
194+
195+
# Get color for this category
188196
colorIndex = detection.get_template_index() % nColors # will return an integer in the range of palette
197+
color = palette[colorIndex]
189198

190199
plt.plot(*detection.get_lists_xy(),
191200
linewidth=thickness,
192-
color=palette[colorIndex])
201+
color=color)
202+
203+
# If show legend, get detection label and current color
204+
if showLegend:
205+
206+
label = detection.get_label()
207+
208+
if label != "":
209+
mapLabelColor[label] = color
210+
211+
# Finally add the legend if mapLabelColor is not empty
212+
if showLegend :
213+
214+
if not mapLabelColor: # Empty label mapping
215+
warnings.warn("No label associated to the templates." +
216+
"Skipping legend.")
217+
218+
else: # meaning mapLabelColor is not empty
219+
220+
legendLabels = []
221+
legendEntries = []
222+
223+
for label, color in mapLabelColor.items():
224+
legendLabels.append(label)
225+
legendEntries.append(Line2D([0], [0], color=color, lw=4))
226+
227+
plt.legend(legendEntries, legendLabels)

test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@
3333
print (np.array(finalHits)) # better formatting with array
3434

3535
#%% Display matches
36-
MTM.plotDetections(image, finalHits)
36+
MTM.plotDetections(image, finalHits, showLegend=True)

test_RGB.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@
2828
print (np.array(finalHits)) # better formatting with array
2929

3030
#%% Display matches
31-
MTM.plotDetections(image, finalHits)
31+
MTM.plotDetections(image, finalHits, showLegend = True)

0 commit comments

Comments
 (0)