Skip to content

Commit b0b9d37

Browse files
committed
Test result is bad
1 parent 2cfbc58 commit b0b9d37

File tree

10 files changed

+137
-70
lines changed

10 files changed

+137
-70
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,5 @@ pip-log.txt
220220
#############
221221
*.mat
222222
*.asv
223+
*.png
224+
*.pts

ESR_Test.m

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function ESR_Test()
1414
disp('lack of initial set of shapes');
1515
end
1616

17-
Data = loadsamples('/Volumes/LG_SDJet/Datasets/lfpw/annotations/trainset', 'png');
17+
Data = loadsamples('lfpw/testset', 'png');
1818
params.N_img = size(Data, 1);
1919
load('Data/InitialShape_68');
2020
dist_pupils_ms = getDistPupils(S0);
@@ -23,78 +23,113 @@ function ESR_Test()
2323

2424
load('Data/Model.mat', 'Model');
2525
%%
26-
for i = 1: 1
27-
Prediction = ShapeRegression(Data{i}, initSet, Model, params);
28-
figure
29-
imshow(Data{i}.img_gray)
30-
hold on
31-
plot(Prediction(:, 1), Prediction(:, 2), 'g+');
32-
hold off
26+
prediction = zeros([size(params.meanshape), params.N_img]);
27+
groundtruth = zeros([size(params.meanshape), params.N_img]);
28+
for i = 1: params.N_img
29+
Prediction = ShapeRegression(Data(i), initSet, Model, params);
30+
prediction(:,:, i) = Prediction;
31+
groundtruth(:,:, i) = Data{i}.shape_gt(params.ind_usedpts,:);
3332
end
34-
33+
fprintf('MSRE is %f\n', mean(compute_error(prediction, groundtruth)));
3534
end
3635

3736
function predict = ShapeRegression(data, initSet, Model, params)
3837
% Multiple initializations
39-
40-
%predict = zeros(params.N_fp, 2);
41-
ctshapes = initialTest(data, initSet, params);
38+
Data = initialize(data, initSet, params);
4239

4340
for t = 1: params.T
44-
for i = 1: params.N_init
45-
prediction_delta = fernCascadeTest(data, ctshapes(:, :, i), Model{t}.fernCascade, params);
46-
ctshapes(:, :, i) = ctshapes(:, :, i) + prediction_delta;
41+
for i = 1: params.N_aug
42+
prediction_delta = fernCascadeTest(Data{i}, Model{t}.fernCascade, params, t);
43+
% update the shape, convert to the current shape
44+
bbx = Data{i}.intermediate_bboxes{t};
45+
shape_stage = Data{i}.intermediate_shapes{t};
46+
delta_shape = prediction_delta;
47+
48+
[u, v] = transformPointsForward(Data{i}.meanshape2tf, delta_shape(:, 1), delta_shape(:, 2));
49+
delta_shape_interm_coord = [u, v];
50+
shape_residual = bsxfun(@times, delta_shape_interm_coord, [bbx(3),bbx(4)]);
51+
shape_newstage = shape_stage + shape_residual;
52+
53+
% update the shape
54+
Data{i}.intermediate_bboxes{t+1} = getbbox(shape_newstage);
55+
Data{i}.intermediate_shapes{t+1} = shape_newstage;
56+
meanshape_reproject = resetshape(Data{i}.intermediate_bboxes{t+1}, params.meanshape);
57+
Data{i}.tf2meanshape = fitgeotrans( bsxfun(@minus, shape_newstage, mean(shape_newstage)), ...
58+
bsxfun(@minus, meanshape_reproject, mean(meanshape_reproject)),...
59+
'nonreflectivesimilarity');
60+
Data{i}.meanshape2tf = fitgeotrans( bsxfun(@minus, meanshape_reproject, mean(meanshape_reproject)),...
61+
bsxfun(@minus, shape_newstage, mean(shape_newstage)), ...
62+
'nonreflectivesimilarity');
63+
% shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - shape_newstage, Data{i}.intermediate_bboxes{t+1}(3:4));
64+
% [u, v] = transformPointsForward(Data{i}.tf2meanshape, shape_residual(:, 1), shape_residual(:, 2));
65+
% Data{i}.shapes_residual = [u, v];
4766
end
4867
end
49-
predict = mean(ctshapes, 3);
68+
69+
% Prediction
70+
% gtshapes = zeros([size(params.meanshape), params.T+1]);
71+
% ctshapes = zeros([size(params.meanshape), params.T+1]);
72+
% for t = 1: params.T+1
73+
% for i = 1: params.N_aug
74+
% ctshapes(:,:, t) = ctshapes(:,:, t) + Data{i}.intermediate_shapes{t};
75+
% gtshapes(: ,:, t) = gtshapes(: ,:, t) + Data{i}.shape_gt;
76+
% end
77+
% end
78+
% ctshapes = ctshapes/params.N_aug;
79+
% gtshapes = gtshapes/params.N_aug;
80+
% Error = zeros(1, params.T+1);
81+
% for t = 1:params.T
82+
% Error(t) = compute_error(ctshapes(:,:, t), gtshapes(:,:, t));
83+
% end
84+
% bar(Error);
85+
86+
predict = zeros([size(params.meanshape), params.N_aug]);
87+
for i = 1:params.N_aug
88+
predict(:, :, i) = Data{i}.intermediate_shapes{end};
89+
end
90+
predict = mean(predict, 3);
5091
end
5192

52-
function prediction_delta = fernCascadeTest(image, current_shape, fernCascade, params)
53-
image.intermediate_bbx = getbbox(current_shape);
54-
meanshape = resetshape(image.intermediate_bbx, params.meanshape);
55-
image.tf2meanshape = fitgeotrans(bsxfun(@minus, current_shape, mean(current_shape)), ...
56-
bsxfun(@minus, meanshape, mean(meanshape)),...
57-
'nonreflectivesimilarity');
58-
image.meanshape2tf = fitgeotrans( bsxfun(@minus, meanshape, mean(meanshape)),...
59-
bsxfun(@minus, current_shape, mean(current_shape)), ...
60-
'nonreflectivesimilarity');
61-
93+
function delta_shape = fernCascadeTest(image, fernCascade, params, t)
6294
%extract shape indexed pixels
6395
candidate_pixel_location = fernCascade.candidate_pixel_location;
6496
nearest_landmark_index = fernCascade.nearest_landmark_index;
6597
intensities = zeros(1, params.P);
6698
for j = 1: params.P
67-
x = candidate_pixel_location(j, 1)*image.intermediate_bbx(3);
68-
y = candidate_pixel_location(j, 2)* image.intermediate_bbx(4);
99+
x = candidate_pixel_location(j, 1)*image.intermediate_bboxes{t}(3);
100+
y = candidate_pixel_location(j, 2)* image.intermediate_bboxes{t}(4);
69101
[project_x, project_y] = transformPointsForward(image.meanshape2tf, x, y);
70102
index = nearest_landmark_index(j);
71103

72-
real_x = round(project_x + current_shape(index, 1));
73-
real_y = round(project_y + current_shape(index, 2));
104+
real_x = round(project_x + image.intermediate_shapes{t}(index, 1));
105+
real_y = round(project_y + image.intermediate_shapes{t}(index, 2));
74106
real_x = max(1, min(real_x, size(image.img_gray, 2)-1));
75107
real_y = max(1, min(real_y, size(image.img_gray, 1)-1));
76108
intensities(j)= image.img_gray(real_y, real_x);
77109
end
78110

79111
delta_shape = zeros(size(params.meanshape));
80-
for i = 1: params.K
81-
fern = fernCascade.ferns{i}.fern;
112+
for k = 1: params.K
113+
fern = fernCascade.ferns{k}.fern;
82114
delta_shape = delta_shape + fernTest(intensities, fern, params);
83115
end
84116

85117
%convert to the currentshape model
86-
[u, v] = transformPointsForward(image.meanshape2tf, delta_shape(:, 1), delta_shape(:, 2));
87-
prediction_delta = [u, v];
88-
prediction_delta = bsxfun(@times, prediction_delta, [image.intermediate_bbx(3),image.intermediate_bbx(4)]);
118+
% [u, v] = transformPointsForward(image.meanshape2tf, delta_shape(:, 1), delta_shape(:, 2));
119+
% prediction_delta = [u, v];
120+
% prediction_delta = bsxfun(@times, prediction_delta, [image.intermediate_bbx(3),image.intermediate_bbx(4)]);
89121
end
90122

91123
function fern_pred = fernTest(intensities, fern, params)
92124
index = 0;
93125
for i = 1: params.F
94-
intensity_1 = intensities(fern.selected_pixel_index(i, 1));
95-
intensity_2 = intensities(fern.selected_pixel_index(i, 2));
126+
m_f = fern.selected_pixel_index(i, 1);
127+
n_f = fern.selected_pixel_index(i, 2);
128+
129+
intensity_1 = intensities(m_f);
130+
intensity_2 = intensities(n_f);
96131

97-
if (intensity_1 - intensity_2) >= fern.threshold(i)
132+
if intensity_1 - intensity_2 >= fern.threshold(i)
98133
index = index + 2^(i-1);
99134
end
100135
end

ESR_Train.m

100755100644
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ function ESR_Train()
33
params = Train_params;
44
% create paralllel local jobs note
55
if isempty(gcp('nocreate'))
6-
parpool(2);
6+
parpool(8);
77
end
88
%% load data
99
if exist('Data/train_init.mat', 'file')
1010
load('Data/train_init.mat', 'data');
1111
else
12-
data = loadsamples('/Volumes/LG_SDJet/Datasets/lfpw/annotations/trainset', 'png');
12+
data = loadsamples('/lfpw/annotations/trainset', 'png');
1313
%mkdir Data;
1414
save('Data/train_init.mat', 'data');
1515
end

Test_params.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
function params = Test_params()
22
params.T = 10; % the iteration stages
3-
params.P = 40; % default = 400, the pixel number sampled on the images
4-
params.K = 50; % default = 500, the number of fern on the internal-level boosted regression
3+
params.P = 400; % default = 400, the pixel number sampled on the images
4+
params.K = 500; % default = 500, the number of fern on the internal-level boosted regression
55
params.k = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3];
66
% the local scale of search, it set 0.3 times of the distance between two pupils on the mean shape
77
params.F = 5; % the number of features in fern
88

9-
params.N_init = 5; % initial number
9+
params.N_aug = 5; % initial number
1010
%params.N_aug = 1;
1111

1212
params.N_fp = 0;% size(params.mean_shape, 1);

Train.fig

12.9 KB
Binary file not shown.

Train_params.m

100755100644
File mode changed.

compute_error.m

100755100644
File mode changed.

getbbox.m

100755100644
File mode changed.

initialTest.m

Lines changed: 0 additions & 27 deletions
This file was deleted.

initialize.m

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
function Data = initialize(data, initSet, params)
2+
% initialize the data
3+
% data: image
4+
% initSet: the inital set of initialization
5+
% D: the number of initialization
6+
Index = 0;
7+
N_data = size(data, 1);
8+
9+
Data = cell(N_data*params.N_aug, 1);
10+
11+
for i = 1: N_data
12+
% random select initial shapes without replacement
13+
rand_index_ = randperm(params.N_img, params.N_aug);
14+
while ismember(i, rand_index_)
15+
rand_index_ = randperm(params.N_img, params.N_aug);
16+
end
17+
% expand the data
18+
for j = 1: params.N_aug
19+
r_index = rand_index_(j);
20+
21+
Index = Index + 1;
22+
% copy the original stuff
23+
Data{Index}.img_gray = data{i}.img_gray;
24+
Data{Index}.width_orig = data{i}.width_orig;
25+
Data{Index}.height_orig = data{i}.height_orig;
26+
Data{Index}.width = data{i}.width;
27+
Data{Index}.height = data{i}.height;
28+
Data{Index}.shape_gt = data{i}.shape_gt(params.ind_usedpts, :);
29+
Data{Index}.bbox_gt = getbbox(Data{Index}.shape_gt);
30+
% add the new element
31+
Data{Index}.intermediate_shapes = cell(1, params.T);
32+
Data{Index}.intermediate_bboxes = cell(1, params.T);
33+
% scale and translate the sampled shape to ground-truth
34+
% face rectangle region
35+
select_shape = resetshape(data{i}.bbox_gt, initSet{r_index}.shape_gt(params.ind_usedpts, :));
36+
37+
Data{Index}.intermediate_shapes{1} = select_shape;
38+
Data{Index}.intermediate_bboxes{1} = getbbox(select_shape);
39+
40+
meanshape_resize = resetshape(Data{Index}.intermediate_bboxes{1}, params.meanshape);
41+
42+
Data{Index}.tf2meanshape = fitgeotrans(bsxfun(@minus, ...
43+
Data{Index}.intermediate_shapes{1}, mean(Data{Index}.intermediate_shapes{1})), ...
44+
bsxfun(@minus, meanshape_resize, mean(meanshape_resize)),...
45+
'nonreflectivesimilarity');
46+
Data{Index}.meanshape2tf = fitgeotrans(bsxfun(@minus, meanshape_resize, mean(meanshape_resize)), ...
47+
bsxfun(@minus, Data{Index}.intermediate_shapes{1}, mean(Data{Index}.intermediate_shapes{1})),...
48+
'nonreflectivesimilarity');
49+
50+
shape_residual = bsxfun(@rdivide, Data{Index}.shape_gt - Data{Index}.intermediate_shapes{1},...
51+
Data{Index}.intermediate_bboxes{1}(3: 4));
52+
53+
[u, v] = transformPointsForward(Data{Index}.tf2meanshape, shape_residual(:, 1), shape_residual(:, 2));
54+
Data{Index}.shapes_residual = [u, v];
55+
end
56+
end
57+
end

0 commit comments

Comments
 (0)