-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdsvdd.m
64 lines (49 loc) · 1.42 KB
/
dsvdd.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
%DSVDD Deep SVDD one-class classifier
%
%      W = DSVDD(A,FRACREJ,N)
%
% INPUT
%   A         Dataset (only its target objects are used for training)
%   FRACREJ   Fraction of target objects rejected (default = 0.1)
%   N         Number of hidden units (default = 5)
%
% OUTPUT
%   W         Trained Deep SVDD mapping
%
% DESCRIPTION
% Train a Deep SVDD network with N hidden units on the target objects of
% dataset A. Training and evaluation of the network are delegated to the
% Python module MAIN (py.main.start / py.main.predict), which therefore
% has to be importable from MATLAB's Python environment. The data are
% passed to Python flattened column-major, together with their size.
% The rejection threshold is obtained with DD_THRESHOLD on the network
% outputs of the training targets, using 1-FRACREJ.
%
% Applying W to a dataset returns two columns per object ('target',
% 'outlier'): the negated network output and the negated threshold.
%
function W = dsvdd(varargin)
argin = shiftargin(varargin,'scalar');
argin = setdefaults(argin,[],0.1,5);
if mapping_task(argin,'definition')
   % No training data supplied: return an untrained mapping definition.
   W = define_mapping(argin,'untrained','DSVDD');

elseif mapping_task(argin,'training')
   [a,fracrej,N] = deal(argin{:});
   x = +target_class(a);    % keep only the target objects, as plain data
   [nrx,dim] = size(x);
   % Python receives the data flattened column-major plus its size
   % (presumably main.py reshapes it back -- verify on the Python side).
   xv = x(:);
   net = py.main.start(xv, nrx, dim, N);
   d = double(py.main.predict(net, xv, nrx, dim));
   d = d(:);                % force a column, whatever orientation Python returns
   % Obtain the threshold from the training outputs:
   W.threshold = dd_threshold(d,1-fracrej);
   % Save everything needed at execution time:
   W.net = net;
   W.scale = mean(d);
   % CHAR replaces the deprecated STR2MAT (identical padded char matrix).
   W = prmapping(mfilename,'trained',W,char('target','outlier'),dim,2);
   W = setname(W,'DSVDD (N=%d)',N);

elseif mapping_task(argin,'trained execution')
   [a,fracrej] = deal(argin{1:2});
   % In this branch the second argument is the trained mapping; unpack it.
   W = getdata(fracrej);
   % Convert to a plain matrix BEFORE flattening, so Python gets exactly
   % the same column-major layout as during training (indexing a dataset
   % with (:) goes through dataset subsref, not plain flattening).
   x = +a;
   [m,dim] = size(x);
   % Compute the network output for every object:
   out = double(py.main.predict(W.net, x(:), m, dim));
   out = out(:);            % one value per object, as an m-by-1 column
   % Columns: [network output, threshold]; both are stored negated.
   out = [out repmat(W.threshold,m,1)];
   W = setdat(a,-out,fracrej);
   W = setfeatdom(W,{[-inf 0;-inf 0] [-inf 0;-inf 0]});

else
   error('Illegal call to DSVDD.');
end