@@ -14,7 +14,7 @@ function ESR_Test()
14
14
disp(' lack of initial set of shapes' );
15
15
end
16
16
17
- Data = loadsamples(' /Volumes/LG_SDJet/Datasets/ lfpw/annotations/trainset ' , ' png' );
17
+ Data = loadsamples(' lfpw/testset ' , ' png' );
18
18
params.N_img = size(Data , 1 );
19
19
load(' Data/InitialShape_68' );
20
20
dist_pupils_ms = getDistPupils(S0 );
@@ -23,78 +23,113 @@ function ESR_Test()
23
23
24
24
load(' Data/Model.mat' , ' Model' );
25
25
%%
26
- for i = 1 : 1
27
- Prediction = ShapeRegression(Data{i }, initSet , Model , params );
28
- figure
29
- imshow(Data{i }.img_gray)
30
- hold on
31
- plot(Prediction(: , 1 ), Prediction(: , 2 ), ' g+' );
32
- hold off
26
+ prediction = zeros([size(params .meanshape ), params .N_img ]);
27
+ groundtruth = zeros([size(params .meanshape ), params .N_img ]);
28
+ for i = 1 : params .N_img
29
+ Prediction = ShapeRegression(Data(i ), initSet , Model , params );
30
+ prediction(: ,: , i ) = Prediction ;
31
+ groundtruth(: ,: , i ) = Data{i }.shape_gt(params .ind_usedpts ,: );
33
32
end
34
-
33
+ fprintf( ' MSRE is %f\n ' , mean(compute_error( prediction , groundtruth )));
35
34
end
36
35
37
36
function predict = ShapeRegression(data , initSet , Model , params )
38
37
% Multiple initializations
39
-
40
- % predict = zeros(params.N_fp, 2);
41
- ctshapes = initialTest(data , initSet , params );
38
+ Data = initialize(data , initSet , params );
42
39
43
40
for t = 1 : params .T
44
- for i = 1 : params .N_init
45
- prediction_delta = fernCascadeTest(data , ctshapes(: , : , i ), Model{t }.fernCascade, params );
46
- ctshapes(: , : , i ) = ctshapes(: , : , i ) + prediction_delta ;
41
+ for i = 1 : params .N_aug
42
+ prediction_delta = fernCascadeTest(Data{i }, Model{t }.fernCascade, params , t );
43
+ % update the shape, convert to the current shape
44
+ bbx = Data{i }.intermediate_bboxes{t };
45
+ shape_stage = Data{i }.intermediate_shapes{t };
46
+ delta_shape = prediction_delta ;
47
+
48
+ [u , v ] = transformPointsForward(Data{i }.meanshape2tf, delta_shape(: , 1 ), delta_shape(: , 2 ));
49
+ delta_shape_interm_coord = [u , v ];
50
+ shape_residual = bsxfun(@times , delta_shape_interm_coord , [bbx(3 ),bbx(4 )]);
51
+ shape_newstage = shape_stage + shape_residual ;
52
+
53
+ % update the shape
54
+ Data{i }.intermediate_bboxes{t + 1 } = getbbox(shape_newstage );
55
+ Data{i }.intermediate_shapes{t + 1 } = shape_newstage ;
56
+ meanshape_reproject = resetshape(Data{i }.intermediate_bboxes{t + 1 }, params .meanshape );
57
+ Data{i }.tf2meanshape = fitgeotrans( bsxfun(@minus , shape_newstage , mean(shape_newstage )), ...
58
+ bsxfun(@minus , meanshape_reproject , mean(meanshape_reproject )),...
59
+ ' nonreflectivesimilarity' );
60
+ Data{i }.meanshape2tf = fitgeotrans( bsxfun(@minus , meanshape_reproject , mean(meanshape_reproject )),...
61
+ bsxfun(@minus , shape_newstage , mean(shape_newstage )), ...
62
+ ' nonreflectivesimilarity' );
63
+ % shape_residual = bsxfun(@rdivide, Data{i}.shape_gt - shape_newstage, Data{i}.intermediate_bboxes{t+1}(3:4));
64
+ % [u, v] = transformPointsForward(Data{i}.tf2meanshape, shape_residual(:, 1), shape_residual(:, 2));
65
+ % Data{i}.shapes_residual = [u, v];
47
66
end
48
67
end
49
- predict = mean(ctshapes , 3 );
68
+
69
+ % Prediction
70
+ % gtshapes = zeros([size(params.meanshape), params.T+1]);
71
+ % ctshapes = zeros([size(params.meanshape), params.T+1]);
72
+ % for t = 1: params.T+1
73
+ % for i = 1: params.N_aug
74
+ % ctshapes(:,:, t) = ctshapes(:,:, t) + Data{i}.intermediate_shapes{t};
75
+ % gtshapes(: ,:, t) = gtshapes(: ,:, t) + Data{i}.shape_gt;
76
+ % end
77
+ % end
78
+ % ctshapes = ctshapes/params.N_aug;
79
+ % gtshapes = gtshapes/params.N_aug;
80
+ % Error = zeros(1, params.T+1);
81
+ % for t = 1:params.T
82
+ % Error(t) = compute_error(ctshapes(:,:, t), gtshapes(:,:, t));
83
+ % end
84
+ % bar(Error);
85
+
86
+ predict = zeros([size(params .meanshape ), params .N_aug ]);
87
+ for i = 1 : params .N_aug
88
+ predict(: , : , i ) = Data{i }.intermediate_shapes{end };
89
+ end
90
+ predict = mean(predict , 3 );
50
91
end
51
92
52
- function prediction_delta = fernCascadeTest(image , current_shape , fernCascade , params )
53
- image.intermediate_bbx = getbbox(current_shape );
54
- meanshape = resetshape(image .intermediate_bbx , params .meanshape );
55
- image.tf2meanshape = fitgeotrans(bsxfun(@minus , current_shape , mean(current_shape )), ...
56
- bsxfun(@minus , meanshape , mean(meanshape )),...
57
- ' nonreflectivesimilarity' );
58
- image.meanshape2tf = fitgeotrans( bsxfun(@minus , meanshape , mean(meanshape )),...
59
- bsxfun(@minus , current_shape , mean(current_shape )), ...
60
- ' nonreflectivesimilarity' );
61
-
93
+ function delta_shape = fernCascadeTest(image , fernCascade , params , t )
62
94
% extract shape indexed pixels
63
95
candidate_pixel_location = fernCascade .candidate_pixel_location ;
64
96
nearest_landmark_index = fernCascade .nearest_landmark_index ;
65
97
intensities = zeros(1 , params .P );
66
98
for j = 1 : params .P
67
- x = candidate_pixel_location(j , 1 )*image .intermediate_bbx (3 );
68
- y = candidate_pixel_location(j , 2 )* image .intermediate_bbx (4 );
99
+ x = candidate_pixel_location(j , 1 )*image.intermediate_bboxes{ t } (3 );
100
+ y = candidate_pixel_location(j , 2 )* image.intermediate_bboxes{ t } (4 );
69
101
[project_x , project_y ] = transformPointsForward(image .meanshape2tf , x , y );
70
102
index = nearest_landmark_index(j );
71
103
72
- real_x = round(project_x + current_shape (index , 1 ));
73
- real_y = round(project_y + current_shape (index , 2 ));
104
+ real_x = round(project_x + image.intermediate_shapes{ t } (index , 1 ));
105
+ real_y = round(project_y + image.intermediate_shapes{ t } (index , 2 ));
74
106
real_x = max(1 , min(real_x , size(image .img_gray , 2 )-1 ));
75
107
real_y = max(1 , min(real_y , size(image .img_gray , 1 )-1 ));
76
108
intensities(j )= image .img_gray(real_y , real_x );
77
109
end
78
110
79
111
delta_shape = zeros(size(params .meanshape ));
80
- for i = 1 : params .K
81
- fern = fernCascade.ferns{i }.fern;
112
+ for k = 1 : params .K
113
+ fern = fernCascade.ferns{k }.fern;
82
114
delta_shape = delta_shape + fernTest(intensities , fern , params );
83
115
end
84
116
85
117
% convert to the currentshape model
86
- [u , v ] = transformPointsForward(image .meanshape2tf , delta_shape(: , 1 ), delta_shape(: , 2 ));
87
- prediction_delta = [u , v ];
88
- prediction_delta = bsxfun(@times , prediction_delta , [image .intermediate_bbx(3 ),image .intermediate_bbx(4 )]);
118
+ % [u, v] = transformPointsForward(image.meanshape2tf, delta_shape(:, 1), delta_shape(:, 2));
119
+ % prediction_delta = [u, v];
120
+ % prediction_delta = bsxfun(@times, prediction_delta, [image.intermediate_bbx(3),image.intermediate_bbx(4)]);
89
121
end
90
122
91
123
function fern_pred = fernTest(intensities , fern , params )
92
124
index = 0 ;
93
125
for i = 1 : params .F
94
- intensity_1 = intensities(fern .selected_pixel_index(i , 1 ));
95
- intensity_2 = intensities(fern .selected_pixel_index(i , 2 ));
126
+ m_f = fern .selected_pixel_index(i , 1 );
127
+ n_f = fern .selected_pixel_index(i , 2 );
128
+
129
+ intensity_1 = intensities(m_f );
130
+ intensity_2 = intensities(n_f );
96
131
97
- if ( intensity_1 - intensity_2 ) >= fern .threshold(i )
132
+ if intensity_1 - intensity_2 >= fern .threshold(i )
98
133
index = index + 2 ^(i - 1 );
99
134
end
100
135
end
0 commit comments