Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0598235

Browse files
committedSep 3, 2018
added the heatmap tutorial code
1 parent 1a93061 commit 0598235

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed
 

‎heatmap/generate_heatmap.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#!/usr/bin/python
2+
#
3+
# Copyright 2018 Orpix Inc.
4+
#
5+
# This script will iterate through a file "labels.txt" which specifies the images and object annotations to
6+
# compute an object exposure heatmap saved as heatmap.png in the location of this script.
7+
# It will highlight the labels on each annotated image and save a copy in the output folder.
8+
9+
import os
10+
import os.path
11+
12+
#we require numpy, opencv, and matplotlib
13+
import numpy as np
14+
import cv2
15+
import matplotlib
16+
17+
#this disables GUI windows from popping up
18+
matplotlib.use('Agg')
19+
import matplotlib.pyplot as plt
20+
21+
22+
#set working directory to location of script:
23+
abspath = os.path.abspath(__file__)
24+
dname = os.path.dirname(abspath)
25+
os.chdir(dname)
26+
27+
#helper function that parses the line
28+
#returns the frame, along with an array of array of point pairs representing detected logos
29+
def parse_line(line):
30+
31+
#format of a line is space delineated in the following format
32+
#
33+
#image_file_path number_of_labels x1 y1 x2 y2 x3 y3 x4 y4 x1 y1 x2 y2 x3 y3 x4 y4 ... etc
34+
#
35+
#Note: We use 4 x,y coordinates to denote a detection label
36+
#since Orpix logo detection outputs a quadrilateral as opposed to a rectangle
37+
toks = line.split(' ')
38+
39+
frame_path = toks[0]
40+
label_count = int(toks[1])
41+
42+
#getting points only, casting to int
43+
pts = toks[2:]
44+
pts = [ int(x) for x in pts ]
45+
46+
labels = []
47+
for i in range(label_count):
48+
49+
#get the 8 points for the current label
50+
label_pts = pts[i*8:i*8+8]
51+
52+
#reshape to array of 4 tuples
53+
pt_pairs = []
54+
for ptind in range(0,8,2):
55+
pt_pairs.append([label_pts[ptind], label_pts[ptind+1]])
56+
57+
#add this label to the list of labels we return
58+
pt_pairs = np.array(pt_pairs)
59+
labels.append(pt_pairs)
60+
61+
return frame_path, np.array(labels)
62+
63+
#help function to highlight the labels in each input frame
64+
#by interpolating the background with white
65+
def highlight_labels(img, labels, maskimg = None):
66+
67+
#create a copy of the image so we can draw on it
68+
imgcpy = img.copy()
69+
70+
#draw a quadrilateral for each label in red
71+
cv2.polylines(imgcpy, labels, True, (0,0,255), thickness=2)
72+
73+
#a mask needs to be created from the labels so we can properly highlight.
74+
#if the mask isn't passed in, we create it
75+
if type(maskimg) == type(None):
76+
maskimg = np.zeros(imgcpy.shape, dtype=np.float)
77+
for label in labels:
78+
#this sets all pixels inside the label to 1
79+
cv2.fillConvexPoly(maskimg, label, (1))
80+
81+
#create a rgb version of the mask by setting each channel to the mask we created
82+
maskimg = (maskimg>0).astype(np.uint8)
83+
maskrgb = np.zeros(imgcpy.shape, np.uint8)
84+
maskrgb[:,:,0] = maskimg
85+
maskrgb[:,:,1] = maskimg
86+
maskrgb[:,:,2] = maskimg
87+
88+
#interpolate image with white using a weighted sum
89+
bgimg = .5*255*np.ones(imgcpy.shape, np.float) + .5*imgcpy.astype(np.float)
90+
#mask out the background
91+
bgimg = (1-maskrgb)*bgimg
92+
#cast to uint8 image
93+
bgimg = np.round(bgimg).astype(np.uint8)
94+
95+
#get foreground unchanged
96+
fgimg = maskrgb*imgcpy
97+
98+
#add white tinted background with unchanged foreground
99+
imgcpy = bgimg + fgimg.astype(np.uint8)
100+
101+
return imgcpy
102+
103+
def main():
104+
105+
#keeps track of exposure time per pixel. Accumulates for each image
106+
accumulated_exposures = None
107+
108+
#frames were sampled at one second per frame. If you sampled frames from a video at a different rate, change this value
109+
seconds_per_frame = 1.0 #if you sampled frames at 10 frames per second, this value would be 0.1
110+
111+
#we open the labels file and will iterate through each line.
112+
#each line contains a reference to the image and the corresponding polygon lables (4 points per label)
113+
#each frame in the labels file was extracted from one video
114+
with open('labels.txt') as f:
115+
lines = f.readlines()
116+
for line in lines:
117+
118+
#parse the line using helper function
119+
frame_path, labels = parse_line(line)
120+
121+
print "processing %s" % frame_path
122+
123+
#load the image
124+
frame = cv2.imread(frame_path)
125+
126+
#this is where the highlighted images will go
127+
if not os.path.exists('output'):
128+
os.mkdir('output')
129+
130+
131+
#if the heatmap is None we create it with same size as frame, single channel
132+
if type(accumulated_exposures) == type(None):
133+
accumulated_exposures = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.float)
134+
135+
#we create a mask where all pixels inside each label are set to number of seconds per frame that the video was sampled at
136+
#so as we accumulate the exposure heatmap counts, each pixel contained inside a label contributes the seconds_per_frame
137+
#to the overall accumulated exposure values
138+
maskimg = np.zeros(accumulated_exposures.shape, dtype=np.float)
139+
for label in labels:
140+
cv2.fillConvexPoly(maskimg, label, (seconds_per_frame))
141+
142+
#highlight the labels on the image and save.
143+
#comment out the 2 lines below if you only want to compute the heatmap
144+
highlighted_image = highlight_labels(frame, labels, maskimg)
145+
cv2.imwrite('output/%s' % os.path.basename(frame_path), highlighted_image)
146+
147+
#accumulate the heatmap object exposure time
148+
accumulated_exposures = accumulated_exposures + maskimg
149+
150+
151+
#
152+
#create final heatmap using matplotlib
153+
#
154+
155+
data = np.array(accumulated_exposures)
156+
#create the figure
157+
fig, axis = plt.subplots()
158+
#set the colormap - there are many options for colormaps - see documentation
159+
#we will use cm.jet
160+
hm = axis.pcolor(data, cmap=plt.cm.jet)
161+
#set axis ranges
162+
axis.set(xlim=[0, data.shape[1]], ylim=[0, data.shape[0]], aspect=1)
163+
#need to invert coordinate for images
164+
axis.invert_yaxis()
165+
#remove the ticks
166+
axis.set_xticks([])
167+
axis.set_yticks([])
168+
169+
#fit the colorbar to the height
170+
shrink_scale = 1.0
171+
aspect = data.shape[0]/float(data.shape[1])
172+
if aspect < 1.0:
173+
shrink_scale = aspect
174+
clb = plt.colorbar(hm, shrink=shrink_scale)
175+
#set title
176+
clb.ax.set_title('Exposure (seconds)', fontsize = 10)
177+
#saves image to same directory that the script is located in (our working directory)
178+
plt.savefig('heatmap.png', bbox_inches='tight')
179+
#close objects
180+
plt.close('all')
181+
182+
183+
if __name__ == '__main__':
184+
185+
main()

0 commit comments

Comments
 (0)
Please sign in to comment.