Skip to content

Commit 5a88cdc

Browse files
authored
Merge pull request #7 from twuilliam/guidedbp
Adding guided backprop
2 parents 3dd38f1 + 9de5f45 commit 5a88cdc

File tree

7 files changed

+101
-13
lines changed

7 files changed

+101
-13
lines changed

README.md

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,27 @@ The paper authors torch implementation: https://github.com/ramprs/grad-cam
99

1010
This code assumes Tensorflow dimension ordering, and uses the VGG16 network in keras.applications by default (the network weights will be downloaded on first use).
1111

12-
TODO: Combine with guided back propagation like in the paper.
13-
1412

1513
Usage: `python grad-cam.py <path_to_image>`
1614

17-
![enter image description here](https://github.com/jacobgil/keras-grad-cam/blob/master/examples/boat.jpg?raw=true) ![enter image description here](https://github.com/jacobgil/keras-grad-cam/blob/master/examples/persian_cat.jpg?raw=true)
15+
16+
##### Examples
17+
18+
![enter image description here](https://github.com/jacobgil/keras-grad-cam/blob/master/examples/boat.jpg?raw=true) ![enter image description here](https://github.com/jacobgil/keras-grad-cam/blob/master/examples/persian_cat.jpg?raw=true)
19+
20+
Example image from the [original implementation](https://github.com/ramprs/grad-cam):
21+
22+
'boxer' (243 or 242 in keras)
23+
24+
![](/examples/cat_dog.png)
25+
![](/examples/cat_dog_242_gradcam.jpg)
26+
![](/examples/cat_dog_242_guided_gradcam.jpg)
27+
28+
'tiger cat' (283 or 282 in keras)
29+
30+
![](/examples/cat_dog.png)
31+
![](/examples/cat_dog_282_gradcam.jpg)
32+
![](/examples/cat_dog_282_guided_gradcam.jpg)
33+
34+
35+

examples/cat_dog.png

88.1 KB
Loading

examples/cat_dog_242_gradcam.jpg

17.9 KB
Loading
8.27 KB
Loading

examples/cat_dog_282_gradcam.jpg

18 KB
Loading
11 KB
Loading

grad-cam.py

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
from keras.applications.vgg16 import VGG16
1+
from keras.applications.vgg16 import (
2+
VGG16, preprocess_input, decode_predictions)
23
from keras.preprocessing import image
3-
from keras.applications.vgg16 import preprocess_input
44
from keras.layers.core import Lambda
55
from keras.models import Sequential
6+
from tensorflow.python.framework import ops
67
import keras.backend as K
78
import tensorflow as tf
89
import numpy as np
10+
import keras
911
import sys
1012
import cv2
1113

1214
def target_category_loss(x, category_index, nb_classes):
13-
return tf.mul(x, K.one_hot([category_index], nb_classes))
15+
return tf.multiply(x, K.one_hot([category_index], nb_classes))
1416

1517
def target_category_loss_output_shape(input_shape):
1618
return input_shape
@@ -27,7 +29,63 @@ def load_image(path):
2729
x = preprocess_input(x)
2830
return x
2931

30-
def grad_cam(input_model, image, category_index, layer_name):
32+
def register_gradient():
33+
if "GuidedBackProp" not in ops._gradient_registry._registry:
34+
@ops.RegisterGradient("GuidedBackProp")
35+
def _GuidedBackProp(op, grad):
36+
dtype = op.inputs[0].dtype
37+
return grad * tf.cast(grad > 0., dtype) * \
38+
tf.cast(op.inputs[0] > 0., dtype)
39+
40+
def compile_saliency_function(model, activation_layer='block5_conv3'):
41+
input_img = model.input
42+
layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])
43+
layer_output = layer_dict[activation_layer].output
44+
max_output = K.max(layer_output, axis=3)
45+
saliency = K.gradients(K.sum(max_output), input_img)[0]
46+
return K.function([input_img, K.learning_phase()], [saliency])
47+
48+
def modify_backprop(model, name):
49+
g = tf.get_default_graph()
50+
with g.gradient_override_map({'Relu': name}):
51+
52+
# get layers that have an activation
53+
layer_dict = [layer for layer in model.layers[1:]
54+
if hasattr(layer, 'activation')]
55+
56+
# replace relu activation
57+
for layer in layer_dict:
58+
if layer.activation == keras.activations.relu:
59+
layer.activation = tf.nn.relu
60+
61+
# re-instanciate a new model
62+
new_model = VGG16(weights='imagenet')
63+
return new_model
64+
65+
def deprocess_image(x):
66+
'''
67+
Same normalization as in:
68+
https://github.com/fchollet/keras/blob/master/examples/conv_filter_visualization.py
69+
'''
70+
if np.ndim(x) > 3:
71+
x = np.squeeze(x)
72+
# normalize tensor: center on 0., ensure std is 0.1
73+
x -= x.mean()
74+
x /= (x.std() + 1e-5)
75+
x *= 0.1
76+
77+
# clip to [0, 1]
78+
x += 0.5
79+
x = np.clip(x, 0, 1)
80+
81+
# convert to RGB array
82+
x *= 255
83+
if K.image_dim_ordering() == 'th':
84+
x = x.transpose((1, 2, 0))
85+
x = np.clip(x, 0, 255).astype('uint8')
86+
return x
87+
88+
def grad_cam(input_model, image, category_index, layer_name):
3189
model = Sequential()
3290
model.add(input_model)
3391

@@ -52,22 +110,34 @@ def grad_cam(input_model, image, category_index, layer_name):
52110

53111
cam = cv2.resize(cam, (224, 224))
54112
cam = np.maximum(cam, 0)
55-
cam = cam / np.max(cam)
113+
heatmap = cam / np.max(cam)
56114

57115
#Return to BGR [0..255] from the preprocessed image
58116
image = image[0, :]
59117
image -= np.min(image)
60118
image = np.minimum(image, 255)
61119

62-
cam = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
120+
cam = cv2.applyColorMap(np.uint8(255*heatmap), cv2.COLORMAP_JET)
63121
cam = np.float32(cam) + np.float32(image)
64122
cam = 255 * cam / np.max(cam)
65-
return np.uint8(cam)
123+
return np.uint8(cam), heatmap
66124

67125
preprocessed_input = load_image(sys.argv[1])
68126

69127
model = VGG16(weights='imagenet')
70128

71-
predicted_class = np.argmax(model.predict(preprocessed_input))
72-
cam = grad_cam(model, preprocessed_input, predicted_class, "block5_pool")
73-
cv2.imwrite("cam.jpg", cam)
129+
predictions = model.predict(preprocessed_input)
130+
top_1 = decode_predictions(predictions)[0][0]
131+
print('Predicted class:')
132+
print('%s (%s) with probability %.2f' % (top_1[1], top_1[0], top_1[2]))
133+
134+
predicted_class = np.argmax(predictions)
135+
cam, heatmap = grad_cam(model, preprocessed_input, predicted_class, "block5_conv3")
136+
cv2.imwrite("gradcam.jpg", cam)
137+
138+
register_gradient()
139+
guided_model = modify_backprop(model, 'GuidedBackProp')
140+
saliency_fn = compile_saliency_function(guided_model)
141+
saliency = saliency_fn([preprocessed_input, 0])
142+
gradcam = saliency[0] * heatmap[..., np.newaxis]
143+
cv2.imwrite("guided_gradcam.jpg", deprocess_image(gradcam))

0 commit comments

Comments
 (0)