Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion caffevis/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def __init__(self, settings, key_bindings):
settings.caffevis_deploy_prototxt,
settings.caffevis_network_weights,
mean = None, # Set to None for now, assign later # self._data_mean,
channel_swap = self._net_channel_swap,
channel_swap = None, #Set to None to extend support to single channel
#self._net_channel_swap,
raw_scale = self._range_scale,
)

Expand Down Expand Up @@ -85,6 +86,8 @@ def __init__(self, settings, key_bindings):
# self._data_mean = tuple(self._data_mean)
if self._data_mean is not None:
self.net.transformer.set_mean(self.net.inputs[0], self._data_mean)
if self.net.blobs[self.net.inputs[0]].data.shape[1] == 3: #RGB
self.net.transformer.set_channel_swap(self.net.inputs[0], self._net_channel_swap)

check_force_backward_true(settings.caffevis_deploy_prototxt)

Expand Down
10 changes: 8 additions & 2 deletions caffevis/caffevis_helper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import numpy as np
import cv2

from image_misc import get_tiles_height_width, caffe_load_image



def net_preproc_forward(net, img, data_hw):
appropriate_shape = data_hw + (3,)
assert img.shape == appropriate_shape, 'img is wrong size (got %s but expected %s)' % (img.shape, appropriate_shape)
#resized = caffe.io.resize_image(img, net.image_dims) # e.g. (227, 227, 3)
if net.blobs[net.inputs[0]].data.shape[1] == 1:
appropriate_shape = data_hw + (1,)
bw_img = 0.299*img[:,:,0] + 0.587*img[:,:,1] + 0.114*img[:,:,2]
img = bw_img.reshape(data_hw[0],data_hw[1],1)
elif net.blobs[net.inputs[0]].data.shape[1] == 3:
appropriate_shape = data_hw + (3,)
assert img.shape == appropriate_shape, 'img is wrong size (got %s but expected %s)' % (img.shape, appropriate_shape)
data_blob = net.transformer.preprocess('data', img) # e.g. (3, 227, 227), mean subtracted and scaled to [0,255]
data_blob = data_blob[np.newaxis,:,:,:] # e.g. (1, 3, 227, 227)
output = net.forward(data=data_blob)
Expand Down