-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathfeature_extract.py
80 lines (75 loc) · 3.22 KB
/
feature_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# -*- coding: utf-8 -*-
import numpy as np
import caffe
import hickle as hkl
class CaffeFeatureExtractor:
    """Extract activations of a chosen blob from a pretrained Caffe network.

    Wraps a ``caffe.Net`` plus an input preprocessor; call
    :meth:`extract_feature` with an image loaded via ``caffe.io.load_image``
    (H, W, C layout, RGB, values in [0.0, 1.0]).
    """

    def __init__(self, model_path, pretrained_path, blob, crop_size, meanfile_path=None, mean_values=None):
        """
        :param model_path: path to the network definition prototxt
        :param pretrained_path: path to the pretrained .caffemodel weights
        :param blob: name of the blob whose activations are extracted
        :param crop_size: side length in pixels of the square network input
        :param meanfile_path: path to a .npy mean array; expected shape (1, C, H, W)
        :param mean_values: sequence of 3 per-channel mean values, used when
            meanfile_path is None
            (NOTE(review): presumably these must match the data blob's channel
            order after the RGB->BGR swap below — confirm against the model)
        :raises ValueError: if neither meanfile_path nor mean_values is given
        """
        caffe.set_mode_gpu()
        self.model_path = model_path
        self.pretrained_path = pretrained_path
        self.blob = blob
        self.crop_size = crop_size
        self.meanfile_path = meanfile_path
        self.mean_values = mean_values
        # Create the network and force a batch size of 1 at the crop size.
        self.net = caffe.Net(self.model_path, self.pretrained_path, caffe.TEST)
        self.net.blobs["data"].reshape(1, 3, self.crop_size, self.crop_size)
        # Build the mean array, either from a file or from per-channel values.
        if self.meanfile_path is not None:
            self.mean = np.load(self.meanfile_path)  # expect shape (1, C, H, W)
            self.mean = self.mean[0]
            self.mean = self.crop_matrix(self.mean, crop_size=self.crop_size)
        elif self.mean_values is not None:
            assert len(self.mean_values) == 3
            self.mean = np.zeros((3, self.crop_size, self.crop_size))
            self.mean[0] = self.mean_values[0]
            self.mean[1] = self.mean_values[1]
            self.mean[2] = self.mean_values[2]
        else:
            raise ValueError("either meanfile_path or mean_values must be provided")
        # Create the preprocessor.
        # Note: caffe.io.load_image() => (H,W,C), RGB, [0.0, 1.0]
        self.transformer = caffe.io.Transformer({"data": self.net.blobs["data"].data.shape})  # for cropping
        self.transformer.set_transpose("data", (2, 0, 1))   # (H,W,C) => (C,H,W)
        self.transformer.set_mean("data", self.mean)        # subtract by mean
        self.transformer.set_raw_scale("data", 255)         # [0.0, 1.0] => [0.0, 255.0]
        self.transformer.set_channel_swap("data", (2, 1, 0))  # RGB => BGR

    def extract_feature(self, img):
        """Run one image through the network and return the target blob.

        :param img: image as returned by caffe.io.load_image (H,W,C, RGB, [0,1])
        :return: activations of ``self.blob`` for this single image
            (batch dimension stripped)
        """
        preprocessed_img = self.transformer.preprocess("data", img)
        out = self.net.forward_all(**{self.net.inputs[0]: preprocessed_img, "blobs": [self.blob]})
        # forward_all returns batched outputs; take the single sample.
        return out[self.blob][0]

    def crop_matrix(self, matrix, crop_size):
        """Center-crop a square matrix along its spatial axes.

        :param matrix numpy.ndarray: matrix, shape = [C,H,W] with H == W
        :param crop_size integer: cropping size
        :return: cropped matrix
        :rtype: numpy.ndarray, shape = [C,crop_size,crop_size]
        """
        assert matrix.shape[1] == matrix.shape[2]
        # Integer floor division for the top-left offset of the crop.
        # (The original used np.floor, which returns a float — floats are
        # invalid as slice indices on modern NumPy and raise TypeError.)
        offset = (matrix.shape[1] - crop_size) // 2
        return matrix[:, offset:offset + crop_size, offset:offset + crop_size]
def create_dataset(net, datalist, dbprefix):
    """Extract features for every image listed in *datalist* and dump them.

    :param net: feature extractor exposing ``extract_feature(img)``
        (e.g. a CaffeFeatureExtractor instance)
    :param datalist: path to a text file; each line is "<image_path> <label>"
        with an integer label
    :param dbprefix: output path prefix; writes <dbprefix>_features.hkl and
        <dbprefix>_labels.hkl via hickle
    """
    with open(datalist) as fr:
        lines = [line.rstrip() for line in fr]
    feats = []
    labels = []
    for line_i, line in enumerate(lines):
        img_path, label = line.split()
        img = caffe.io.load_image(img_path)
        feats.append(net.extract_feature(img))
        labels.append(int(label))
        # Progress report every 100 images. %-formatting inside parentheses
        # prints identically on Python 2 and 3 (the original bare print
        # statement is a SyntaxError on Python 3).
        if (line_i + 1) % 100 == 0:
            print("processed %d" % (line_i + 1))
    feats = np.asarray(feats)
    labels = np.asarray(labels)
    hkl.dump(feats, dbprefix + "_features.hkl", mode="w")
    hkl.dump(labels, dbprefix + "_labels.hkl", mode="w")