-
Notifications
You must be signed in to change notification settings - Fork 0
/
controller_without_error.m
99 lines (80 loc) · 2.4 KB
/
controller_without_error.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
function [xE,V,rate,spikes,error,forward_signal,rDec,iDec,failed] = controller_without_error(A,B,C,xT,dxT,Nneuron,dt,M,rDec,iDec,mu,nu,lambdaD,sigmaV,lambdaV)
% CONTROLLER_WITHOUT_ERROR Simulate a spiking network whose decoded output u
% drives the linear system  xE(:,i+1) = (I + dt*A)*xE(:,i) + dt*B*u(:,i+1),
% with the neurons receiving the feedforward signal dxT - A*xE. Decoders are
% fixed (nothing is learned here).
%
% Inputs
%   A        system matrix (J x J, J = size(A,1))
%   B        input matrix (J x b, b = size(B,2))
%   C        output matrix; only C'*C is used
%   xT       target trajectory -- NOTE: not referenced by this function
%   dxT      target derivative; columns dxT(:,2..M) are read
%   Nneuron  number of neurons
%   dt       timestep
%   M        number of timesteps
%   rDec     rate decoder (b x Nneuron)
%   iDec     instantaneous decoder (b x Nneuron)
%   mu       L1 regularization (thresholds and diagonals of Ws / Wf)
%   nu       L2 regularization (thresholds)
%   lambdaD  decoder (rate) leak
%   sigmaV   voltage noise amplitude (multiplied by sqrt(dt))
%   lambdaV  NOT used as a leak. A value of exactly 1 switches on the extra
%            voltage term dt*iDec'*B'*B*iDec*rate; any other value switches
%            it off. The voltage leak itself is always 0.
%
% Outputs
%   xE              network estimate of the system state (J x M)
%   V               membrane voltages (Nneuron x M)
%   rate            filtered firing rates (Nneuron x M)
%   spikes          spike counts per neuron and timestep (Nneuron x M)
%   error           allocated as zeros(J,M) and never updated (always zero)
%   forward_signal  feedforward signal dxT - A*xE (J x M)
%   rDec, iDec      the decoders, returned unchanged
%   failed          true if a single timestep needed more than
%                   maxSpikesPerStep spikes; the simulation stops at that
%                   point and all later columns of the outputs stay zero.

% lambdaV is repurposed as an on/off switch for the extra rate term ...
if lambdaV == 1
    addition = 1;
else
    addition = 0;
end
% ... and the actual voltage leak is deliberately disabled.
lambdaV = 0;

maxSpikesPerStep = 5e6;   % guard against a runaway spiking loop

failed = false;
J = size(A,1);            % number of state variables
b = size(B,2);            % number of input variables
spikes = zeros(Nneuron,M);
V = zeros(Nneuron,M);
rate = zeros(Nneuron,M);
xE = zeros(J,M);          % network estimate
u = zeros(b,M);           % control signal
forward_signal = zeros(J,M);
error = zeros(J,M);       % kept for the output signature; never filled

% Network parameters (B'*C'*C*B is shared by all three).
G = B'*(C'*C)*B;
Thresh = (diag(iDec'*G*iDec) + lambdaD*nu + lambdaD^2*mu)/2;
Ws = -iDec'*G*rDec/lambdaD + 1*mu*lambdaD^2.*eye(Nneuron);  % slow (rate) connections
Wf = -iDec'*G*iDec - mu*lambdaD^2.*eye(Nneuron);            % fast (spike) connections
M2 = iDec'* B'*(C'*C);                                      % feedforward weights

% Loop invariants, written with the same left-to-right association as the
% original per-step expressions so the numerical results are identical.
leak       = 1 - lambdaV*dt;
dtM2       = dt*M2;
dtWs       = dt*Ws;
noiseScale = sqrt(dt)*sigmaV;
% NOTE(review): the original author marked this term "works only without"
% (presumably without the voltage leak) and "useless kind of" -- confirm
% whether it is still needed.
extra      = addition*dt*iDec'*B'*B*iDec;
rateDecay  = 1 - dt*lambdaD;
rGain      = rDec/lambdaD;
iGain      = iDec/dt;
Ad         = eye(J) + dt*A;
dtB        = dt*B;

for i = 1:M-1
    noise = randn(Nneuron,1);
    V(:,i+1) = leak*V(:,i) ...
        + dtM2*forward_signal(:,i) ...
        + dtWs*rate(:,i) ...
        + noiseScale*noise ...
        + extra*rate(:,i);

    % Emit spikes greedily, one at a time: the neuron furthest above its
    % threshold fires and its fast connections are applied to all voltages,
    % until no neuron is above threshold. (Thresholding all neurons at once,
    % ff = V(:,i+1) > Thresh, was marked as wrong by the original author.)
    nSpikes = 0;
    [m,k] = max(V(:,i+1)-Thresh);
    while m > 0
        spikes(k,i+1) = spikes(k,i+1) + 1;
        V(:,i+1) = V(:,i+1) + Wf(:,k);
        [m,k] = max(V(:,i+1)-Thresh);
        nSpikes = nSpikes + 1;
        if nSpikes > maxSpikesPerStep
            failed = true;
            return
        end
    end

    rate(:,i+1) = rateDecay*rate(:,i) + lambdaD*spikes(:,i+1);
    u(:,i+1) = rGain*rate(:,i+1) + iGain*spikes(:,i+1);
    xE(:,i+1) = Ad*xE(:,i) + dtB*u(:,i+1);
    forward_signal(:,i+1) = (dxT(:,i+1)-A*xE(:,i+1));
end
end