-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_model_SVM_rho02.m
91 lines (75 loc) · 2.05 KB
/
test_model_SVM_rho02.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
%% classification on ranData_rho02
load data/ranData_rho02;
%% basic parameter
[p,N] = size(Xtrain);
w_init = randn(p,1);
b_init = 0;
t_init = zeros(N,1);
%% Testing by student code
fprintf('Testing by student code\n\n');
% parameters
lam = 0.1;
opts = [];
opts.tol = 1e-3;
opts.maxit = 5000;
opts.subtol = 1e-3;
opts.maxsubit = 10000;
opts.beta = 1;
opts.w0 = w_init;
opts.b0 = b_init;
opts.t0 = t_init;
% train
t0 = tic;
[w_s,b_s,out_s] = ADMM_SVM(Xtrain, ytrain, lam, opts);
time = toc(t0);
% do classification on the testing data
pred_y = sign(Xtest'*w_s + b_s);
accu = sum(pred_y==ytest)/length(ytest);
% print results
fprintf('Running time is %5.4f\n',time);
fprintf('classification accuracy on testing data: %4.2f%%\n\n',accu*100);
fig = figure('papersize',[5,4],'paperposition',[0,0,5,4]);
semilogy(out_s.hist_pres,'b-','linewidth',2);
hold on;
semilogy(out_s.hist_dres,'r-','linewidth',2);
legend('Primal residual','dual residual','location','best');
xlabel('outer iteration');
ylabel('error');
title('student: ranData\_rho02');
set(gca,'fontsize',14)
hold off;
print(fig,'-dpdf','ranData_rho02_student');
%% Testing by instructor code
fprintf('Testing by instructor code\n\n');
% parameters
lam = 0.1;
opts = [];
opts.tol = 1e-3;
opts.maxit = 1000;
opts.subtol = 1e-4;
opts.maxsubit = 10000;
opts.beta = 1;
opts.w0 = w_init;
opts.b0 = b_init;
opts.t0 = t_init;
% train
t0 = tic;
[w_p,b_p,out_p] = ALM_SVM_p(Xtrain,ytrain,lam,opts);
time = toc(t0);
% do classification on the testing data
pred_y = sign(Xtest'*w_p + b_p);
accu = sum(pred_y==ytest)/length(ytest);
% print results
fprintf('Running time is %5.4f\n',time);
fprintf('classification accuracy on testing data: %4.2f%%\n\n',accu*100);
fig = figure('papersize',[5,4],'paperposition',[0,0,5,4]);
semilogy(out_p.hist_pres,'b-','linewidth',2);
hold on;
semilogy(out_p.hist_dres,'r-','linewidth',2);
legend('Primal residual','dual residual','location','best');
xlabel('outer iteration');
ylabel('error');
title('instructor: ranData\_rho02');
set(gca,'fontsize',14);
hold off;
print(fig,'-dpdf','ranData_rho02_instructor');