Skip to content

Commit 81159a9

Browse files
committed
Pull request no. 13 from ilabcode (DAtanassova)
1 parent 206198a commit 81159a9

File tree

4 files changed

+623
-0
lines changed

4 files changed

+623
-0
lines changed

tapas_ehgf_ar1_binary_mab.m

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
function [traj, infStates] = tapas_ehgf_ar1_binary_mab(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF in a multi-armed bandit
% situation with binary outcomes
%
% This function can be called in two ways:
%
% (1) tapas_ehgf_ar1_binary_mab(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_ehgf_ar1_binary_mab(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% INPUT
%   r   Structure from which the following fields are read:
%         r.c_prc.n_levels             number of levels l (if missing, l is derived from length(p))
%         r.c_prc.n_bandits            number of bandits b (mandatory)
%         r.c_prc.coupled              true for coupled updating (only allowed for b == 2)
%         r.c_prc.irregular_intervals  true if the last column of r.u holds the time axis
%         r.u                          inputs; column 1 is the outcome of each trial
%         r.y, r.irr                   chosen bandit per trial and irregular trials (estimation);
%                                      if unavailable, r.u(:,2) is used as the choices and its
%                                      NaN entries as irregular trials (simulation)
%         r.ign                        trials on which no update takes place
%   p   Parameter vector of length 7*l-1, laid out as
%         [mu_0 (l), sa_0 (l), phi (l), m (l), rho (l), ka (l-1), om (l-1), log(th) (1)]
%
% OUTPUT
%   traj       Structure with fields mu, sa, muhat, sahat, v, w, da, ud, psi, epsi, wt
%   infStates  Array of size (n_trials, l, b, 4) holding, along its 4th dimension,
%              muhat, sahat, mu, and sa
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2017 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.


% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    % Terminated with a semicolon so that p is not echoed to the console on every call
    p = tapas_ehgf_ar1_binary_mab_transp(r, p);
end

% Number of levels
try
    l = r.c_prc.n_levels;
catch
    % p holds 5 blocks of length l, 2 blocks of length l-1, and 1 scalar: length(p) = 7*l-1
    l = (length(p)+1)/7;

    if l ~= floor(l)
        error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
    end
end

% Number of bandits
try
    b = r.c_prc.n_bandits;
catch
    error('tapas:hgf:NumOfBanditsConfig', 'Number of bandits has to be configured in r.c_prc.n_bandits.');
end

% Coupled updating
% This is only allowed if there are 2 bandits. We here assume that the mu1hat for the two bandits
% add to unity.
coupled = false;
if r.c_prc.coupled == true
    if b == 2
        coupled = true;
    else
        error('tapas:hgf:HgfBinaryMab:CoupledOnlyForTwo', 'Coupled updating can only be configured for 2 bandits.');
    end
end

% Unpack parameters
mu_0 = p(1:l);
sa_0 = p(l+1:2*l);
phi  = p(2*l+1:3*l);
m    = p(3*l+1:4*l);
rho  = p(4*l+1:5*l);
ka   = p(5*l+1:6*l-1);
om   = p(6*l:7*l-2);
th   = exp(p(7*l-1));

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
try % For estimation
    y = [1; r.y(:,1)];
    irr = r.irr;
catch % For simulation
    y = [1; r.u(:,2)];
    irr = find(isnan(r.u(:,2)));
end

% Number of trials (including prior)
n = size(u,1);

% Construct time axis
if r.c_prc.irregular_intervals
    % The check has to be made on r.u: u itself only ever holds the first column of r.u, so
    % testing size(u,2) would make this branch fail unconditionally.
    % NOTE(review): when the choices are taken from r.u(:,2) (simulation), the time axis presumably
    % has to be supplied in a further column - confirm against callers.
    if size(r.u,2) > 1
        t = [0; r.u(:,end)];
    else
        error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than one column if irregular_intervals is set to true.');
    end
else
    t = ones(n,1);
end

% Initialize updated quantities

% Representations
mu = NaN(n,l,b);
pi = NaN(n,l,b);

% Other quantities
muhat = NaN(n,l,b);
pihat = NaN(n,l,b);
v     = NaN(n,l);
w     = NaN(n,l-1);
da    = NaN(n,l);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu(1,1,:)     = tapas_sgm(mu_0(2), 1);
muhat(1,1,:)  = mu(1,1,:);
pihat(1,1,:)  = 0;
pi(1,1,:)     = Inf;
mu(1,2:end,:) = repmat(mu_0(2:end),[1 1 b]);
pi(1,2:end,:) = repmat(1./sa_0(2:end),[1 1 b]);

% Pass through representation update loop
for k = 2:1:n
    if not(ismember(k-1, r.ign))

        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%

        % 2nd level prediction
        % All three subscripts are required here: with only two, the assignment would reach
        % bandit 1 alone and leave the predictions of the remaining bandits at NaN.
        muhat(k,2,:) = mu(k-1,2,:) +t(k) *rho(2) +t(k) *phi(2) *(m(2) -mu(k-1,2,:));

        % 1st level
        % ~~~~~~~~~
        % Prediction
        muhat(k,1,:) = tapas_sgm(ka(1) *muhat(k,2,:), 1);

        % Precision of prediction
        pihat(k,1,:) = 1/(muhat(k,1,:).*(1 -muhat(k,1,:)));

        % Updates (only the chosen bandit y(k) observes the outcome u(k))
        pi(k,1,:) = pihat(k,1,:);
        pi(k,1,y(k)) = Inf;

        mu(k,1,:) = muhat(k,1,:);
        mu(k,1,y(k)) = u(k);

        % Prediction error
        da(k,1) = mu(k,1,y(k)) -muhat(k,1,y(k));

        % 2nd level
        % ~~~~~~~~~
        % Prediction: see above

        % Precision of prediction
        pihat(k,2,:) = 1/(1/pi(k-1,2,:) +exp(ka(2) *mu(k-1,3,:) +om(2)));

        % Updates
        pi(k,2,:) = pihat(k,2,:) +ka(1)^2/pihat(k,1,:);

        mu(k,2,:) = muhat(k,2,:);
        mu(k,2,y(k)) = muhat(k,2,y(k)) +ka(1)/pi(k,2,y(k)) *da(k,1);

        % Volatility prediction error
        da(k,2) = (1/pi(k,2,y(k)) +(mu(k,2,y(k)) -muhat(k,2,y(k)))^2) *pihat(k,2,y(k)) -1;

        if l > 3
            % Pass through higher levels
            % ~~~~~~~~~~~~~~~~~~~~~~~~~~
            for j = 3:l-1
                % Prediction
                % NOTE(review): unlike level 2 and the last level, no drift term t(k)*rho(j) is
                % applied here - confirm whether this is intended.
                muhat(k,j,:) = mu(k-1,j,:) +t(k) *phi(j) *(m(j) -mu(k-1,j,:));

                % Precision of prediction
                pihat(k,j,:) = 1/(1/pi(k-1,j,:) +t(k) *exp(ka(j) *mu(k-1,j+1,:) +om(j)));

                % Weighting factor
                v(k,j-1) = t(k) *exp(ka(j-1) *mu(k-1,j,y(k)) +om(j-1));
                w(k,j-1) = v(k,j-1) *pihat(k,j-1,y(k));

                % Mean updates
                % The same scalar update is applied to every bandit; the quantities feeding it
                % are read from the chosen bandit y(k).
                mu(k,j,:) = muhat(k,j,:) +1/2 *1/pihat(k,j,y(k)) *ka(j-1) *w(k,j-1) *da(k,j-1);

                % Ingredients of precision update which depend on the mean
                % update
                % The level below is read for the chosen bandit y(k), in line with v, w and da;
                % at level 2 the bandits differ, so the bandit index matters.
                vv     = t(k) *exp(ka(j-1) *mu(k,j,y(k)) +om(j-1));
                pimhat = 1/(1/pi(k-1,j-1,y(k)) +vv);
                ww     = vv *pimhat;
                rr     = (vv -1/pi(k-1,j-1,y(k))) *pimhat;
                dd     = (1/pi(k,j-1,y(k)) +(mu(k,j-1,y(k)) -muhat(k,j-1,y(k)))^2) *pimhat -1;

                % Precision update (the increment is floored at zero)
                pi(k,j,:) = pihat(k,j,:) +max(0, 1/2 *ka(j-1)^2 *ww*(ww +rr*dd));

                % Volatility prediction error
                da(k,j) = (1/pi(k,j,y(k)) +(mu(k,j,y(k)) -muhat(k,j,y(k)))^2) *pihat(k,j,y(k)) -1;
            end
        end

        % Last level
        % ~~~~~~~~~~
        % Prediction
        muhat(k,l,:) = mu(k-1,l,:) +t(k) *rho(l) +t(k) *phi(l) *(m(l) -mu(k-1,l,:));

        % Precision of prediction
        pihat(k,l,:) = 1/(1/pi(k-1,l,:) +t(k) *th);

        % Weighting factor
        v(k,l)   = t(k) *th;
        v(k,l-1) = t(k) *exp(ka(l-1) *mu(k-1,l,y(k)) +om(l-1));
        w(k,l-1) = v(k,l-1) *pihat(k,l-1,y(k));

        % Mean updates
        mu(k,l,:) = muhat(k,l,:) +1/2 *1/pihat(k,l,y(k)) *ka(l-1) *w(k,l-1) *da(k,l-1);

        % Ingredients of the precision update which depend on the mean
        % update
        % As above, the level below is read for the chosen bandit y(k).
        vv     = t(k) *exp(ka(l-1) *mu(k,l,y(k)) +om(l-1));
        pimhat = 1/(1/pi(k-1,l-1,y(k)) +vv);
        ww     = vv *pimhat;
        rr     = (vv -1/pi(k-1,l-1,y(k))) *pimhat;
        dd     = (1/pi(k,l-1,y(k)) +(mu(k,l-1,y(k)) -muhat(k,l-1,y(k)))^2) *pimhat -1;

        % Precision update (the increment is floored at zero)
        pi(k,l,:) = pihat(k,l,:) +max(0, 1/2 *ka(l-1)^2 *ww*(ww +rr*dd));

        % Volatility prediction error
        da(k,l) = (1/pi(k,l,y(k)) +(mu(k,l,y(k)) -muhat(k,l,y(k)))^2) *pihat(k,l,y(k)) -1;

        % Coupled updating: mirror the chosen bandit's first two levels onto the other bandit
        if coupled == true
            if y(k) == 1
                mu(k,1,2) = 1 -mu(k,1,1);
                mu(k,2,2) = tapas_logit(1 -tapas_sgm(mu(k,2,1), 1), 1);
            elseif y(k) == 2
                mu(k,1,1) = 1 -mu(k,1,2);
                mu(k,2,1) = tapas_logit(1 -tapas_sgm(mu(k,2,2), 1), 1);
            end
        end
    else
        % Ignored trial: carry all quantities forward unchanged

        mu(k,:,:) = mu(k-1,:,:);
        pi(k,:,:) = pi(k-1,:,:);

        muhat(k,:,:) = muhat(k-1,:,:);
        pihat(k,:,:) = pihat(k-1,:,:);

        v(k,:)  = v(k-1,:);
        w(k,:)  = w(k-1,:);
        da(k,:) = da(k-1,:);

    end
end

% Remove representation priors
mu(1,:,:)  = [];
pi(1,:,:)  = [];

% Remove other dummy initial values
muhat(1,:,:) = [];
pihat(1,:,:) = [];
v(1,:)       = [];
w(1,:)       = [];
da(1,:)      = [];
y(1)         = [];

% Responses on regular trials
yreg = y;
yreg(irr) =[];

% Implied learning rate at the first level
% (values are picked, trial by trial, for the bandit that was chosen)
mu2 = squeeze(mu(:,2,:));
mu2(irr,:) = [];
mu2obs = mu2(sub2ind(size(mu2), (1:size(mu2,1))', yreg));

mu1hat = squeeze(muhat(:,1,:));
mu1hat(irr,:) = [];
mu1hatobs = mu1hat(sub2ind(size(mu1hat), (1:size(mu1hat,1))', yreg));

upd1 = tapas_sgm(ka(1)*mu2obs,1) -mu1hatobs;

dareg = da;
dareg(irr,:) = [];

lr1reg = upd1./dareg(:,1);
lr1    = NaN(n-1,1);
lr1(setdiff(1:n-1, irr)) = lr1reg;

% Create result data structure
traj = struct;

traj.mu     = mu;
traj.sa     = 1./pi;

traj.muhat  = muhat;
traj.sahat  = 1./pihat;

traj.v      = v;
traj.w      = w;
traj.da     = da;

% Updates with respect to prediction
traj.ud = mu -muhat;

% Psi (precision weights on prediction errors)
psi = NaN(n-1,l);

pi2 = squeeze(pi(:,2,:));
pi2(irr,:) = [];
pi2obs = pi2(sub2ind(size(pi2), (1:size(pi2,1))', yreg));

psi(setdiff(1:n-1, irr), 2) = 1./pi2obs;

for i=3:l
    pihati = squeeze(pihat(:,i-1,:));
    pihati(irr,:) = [];
    pihatiobs = pihati(sub2ind(size(pihati), (1:size(pihati,1))', yreg));

    pii = squeeze(pi(:,i,:));
    pii(irr,:) = [];
    piiobs = pii(sub2ind(size(pii), (1:size(pii,1))', yreg));

    psi(setdiff(1:n-1, irr), i) = pihatiobs./piiobs;
end

traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi        = NaN(n-1,l);
epsi(:,2:l) = psi(:,2:l) .*da(:,1:l-1);
traj.epsi   = epsi;

% Full learning rate (full weights on prediction errors)
wt        = NaN(n-1,l);
wt(:,1)   = lr1;
wt(:,2)   = psi(:,2);
wt(:,3:l) = 1/2 *(v(:,2:l-1) *diag(ka(2:l-1))) .*psi(:,3:l);
traj.wt   = wt;

% Create matrices for use by the observation model
infStates = NaN(n-1,l,b,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;

end

0 commit comments

Comments
 (0)