Merge pull request #1 from ComputationalPsychiatry:v7.1.3
Pull request no. 13 from ilabcode (DAtanassova)
Showing 4 changed files with 623 additions and 0 deletions.
@@ -0,0 +1,347 @@
function [traj, infStates] = tapas_ehgf_ar1_binary_mab(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF in a multi-armed bandit
% situation with binary outcomes
%
% This function can be called in two ways:
%
% (1) tapas_ehgf_ar1_binary_mab(r, p)
%
% where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_ehgf_ar1_binary_mab(r, ptrans, 'trans')
%
% where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
% transformed space, and 'trans' is a flag indicating this.
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2017 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.
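%
% Usage sketch (illustrative only; assumes r is the structure returned by
% tapas_fitModel and p a parameter vector in native space, as described above):
%
%     [traj, infStates] = tapas_ehgf_ar1_binary_mab(r, p);
%     [traj, infStates] = tapas_ehgf_ar1_binary_mab(r, ptrans, 'trans');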

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_ehgf_ar1_binary_mab_transp(r, p); % back-transformation for this model
end

% Number of levels
try
    l = r.c_prc.n_levels;
catch
    l = (length(p)+1)/7; % with rho included, p has 7*l-1 elements

    if l ~= floor(l)
        error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
    end
end

% Number of bandits
try
    b = r.c_prc.n_bandits;
catch
    error('tapas:hgf:NumOfBanditsConfig', 'Number of bandits has to be configured in r.c_prc.n_bandits.');
end

% Coupled updating
% This is only allowed if there are 2 bandits. Here we assume that the values of
% mu1hat for the two bandits sum to unity.
coupled = false;
if r.c_prc.coupled == true
    if b == 2
        coupled = true;
    else
        error('tapas:hgf:HgfBinaryMab:CoupledOnlyForTwo', 'Coupled updating can only be configured for 2 bandits.');
    end
end

% Unpack parameters
mu_0 = p(1:l);
sa_0 = p(l+1:2*l);
phi = p(2*l+1:3*l);
m = p(3*l+1:4*l);
rho = p(4*l+1:5*l); % drift parameters (one per level)
ka = p(5*l+1:6*l-1);
om = p(6*l:7*l-2);
th = exp(p(7*l-1));
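
% Layout of p as unpacked above: mu_0, sa_0, phi, m, and rho each contribute l
% elements, ka and om contribute l-1 elements each, and the final element is
% log(theta), giving 7*l-1 parameters in total.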

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
try % For estimation
    y = [1; r.y(:,1)];
    irr = r.irr;
catch % For simulation
    y = [1; r.u(:,2)];
    irr = find(isnan(r.u(:,2)));
end
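
% (In estimation mode the responses y and the irregular-trial indices irr come
% from r.y and r.irr; in simulation mode responses are read from the second
% column of r.u, with NaN entries treated as irregular trials.)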

% Number of trials (including prior)
n = size(u,1);

% Construct time axis
if r.c_prc.irregular_intervals
    if size(u,2) > 1
        t = [0; r.u(:,end)];
    else
        error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than one column if irregular_intervals is set to true.');
    end
else
    t = ones(n,1);
end

% Initialize updated quantities

% Representations
mu = NaN(n,l,b);
pi = NaN(n,l,b);

% Other quantities
muhat = NaN(n,l,b);
pihat = NaN(n,l,b);
v = NaN(n,l);
w = NaN(n,l-1);
da = NaN(n,l);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu(1,1,:) = tapas_sgm(mu_0(2), 1);
muhat(1,1,:) = mu(1,1,:);
pihat(1,1,:) = 0;
pi(1,1,:) = Inf;
mu(1,2:end,:) = repmat(mu_0(2:end),[1 1 b]);
pi(1,2:end,:) = repmat(1./sa_0(2:end),[1 1 b]);
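
% (mu, pi, muhat, and pihat are n-by-l-by-b arrays whose third dimension indexes
% the bandits; the prior first-level belief is the logistic sigmoid of mu_0(2)
% and is identical across bandits.)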

% Pass through representation update loop
for k = 2:1:n
    if not(ismember(k-1, r.ign))

        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%

        % 2nd level prediction
        muhat(k,2) = mu(k-1,2) +t(k) *rho(2) +t(k) *phi(2) *(m(2) -mu(k-1,2));

        % 1st level
        % ~~~~~~~~~
        % Prediction
        muhat(k,1,:) = tapas_sgm(ka(1) *muhat(k,2,:), 1);

        % Precision of prediction
        pihat(k,1,:) = 1/(muhat(k,1,:).*(1 -muhat(k,1,:)));

        % Updates
        pi(k,1,:) = pihat(k,1,:);
        pi(k,1,y(k)) = Inf;

        mu(k,1,:) = muhat(k,1,:);
        mu(k,1,y(k)) = u(k);

        % Prediction error
        da(k,1) = mu(k,1,y(k)) -muhat(k,1,y(k));
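        % (da(k,1) is the outcome prediction error for the bandit chosen on
        % trial k; at the first level the posterior is the observed outcome
        % u(k) for the chosen bandit and the prediction for the others.)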

        % 2nd level
        % ~~~~~~~~~
        % Prediction: see above

        % Precision of prediction
        pihat(k,2,:) = 1/(1/pi(k-1,2,:) +exp(ka(2) *mu(k-1,3,:) +om(2)));

        % Updates
        pi(k,2,:) = pihat(k,2,:) +ka(1)^2/pihat(k,1,:);

        mu(k,2,:) = muhat(k,2,:);
        mu(k,2,y(k)) = muhat(k,2,y(k)) +ka(1)/pi(k,2,y(k)) *da(k,1);

        % Volatility prediction error
        da(k,2) = (1/pi(k,2,y(k)) +(mu(k,2,y(k)) -muhat(k,2,y(k)))^2) *pihat(k,2,y(k)) -1;
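        % (da(k,2) compares the posterior uncertainty plus squared update at
        % level 2 for the chosen bandit against its predicted uncertainty;
        % it drives the updates at level 3 and above.)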

        if l > 3
            % Pass through higher levels
            % ~~~~~~~~~~~~~~~~~~~~~~~~~~
            for j = 3:l-1
                % Prediction
                muhat(k,j,:) = mu(k-1,j,:) +t(k) *phi(j) *(m(j) -mu(k-1,j));

                % Precision of prediction
                pihat(k,j,:) = 1/(1/pi(k-1,j,:) +t(k) *exp(ka(j) *mu(k-1,j+1,:) +om(j)));

                % Weighting factor
                v(k,j-1) = t(k) *exp(ka(j-1) *mu(k-1,j,y(k)) +om(j-1));
                w(k,j-1) = v(k,j-1) *pihat(k,j-1,y(k));

                % Mean update
                mu(k,j,:) = muhat(k,j) +1/2 *1/pihat(k,j) *ka(j-1) *w(k,j-1) *da(k,j-1);

                % Ingredients of the precision update which depend on the mean
                % update
                vv = t(k) *exp(ka(j-1) *mu(k,j) +om(j-1));
                pimhat = 1/(1/pi(k-1,j-1) +vv);
                ww = vv *pimhat;
                rr = (vv -1/pi(k-1,j-1)) *pimhat;
                dd = (1/pi(k,j-1) +(mu(k,j-1) -muhat(k,j-1))^2) *pimhat -1;

                % Precision update
                pi(k,j,:) = pihat(k,j,:) +max(0, 1/2 *ka(j-1)^2 *ww*(ww +rr*dd));
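                % (The mean at this level is updated before its precision; vv, ww,
                % rr, and dd are recomputed from the updated mean mu(k,j), and the
                % max(0, .) guard keeps the precision increment non-negative.)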

                % Volatility prediction error
                da(k,j) = (1/pi(k,j,y(k)) +(mu(k,j,y(k)) -muhat(k,j,y(k)))^2) *pihat(k,j,y(k)) -1;
            end
        end

        % Last level
        % ~~~~~~~~~~
        % Prediction
        muhat(k,l,:) = mu(k-1,l,:) +t(k) *rho(l) +t(k) *phi(l) *(m(l) -mu(k-1,l));

        % Precision of prediction
        pihat(k,l,:) = 1/(1/pi(k-1,l,:) +t(k) *th);

        % Weighting factor
        v(k,l) = t(k) *th;
        v(k,l-1) = t(k) *exp(ka(l-1) *mu(k-1,l,y(k)) +om(l-1));
        w(k,l-1) = v(k,l-1) *pihat(k,l-1,y(k));

        % Mean updates
        mu(k,l,:) = muhat(k,l,:) +1/2 *1/pihat(k,l) *ka(l-1) *w(k,l-1) *da(k,l-1);

        % Ingredients of the precision update which depend on the mean
        % update
        vv = t(k) *exp(ka(l-1) *mu(k,l) +om(l-1));
        pimhat = 1/(1/pi(k-1,l-1) +vv);
        ww = vv *pimhat;
        rr = (vv -1/pi(k-1,l-1)) *pimhat;
        dd = (1/pi(k,l-1) +(mu(k,l-1) -muhat(k,l-1))^2) *pimhat -1;

        pi(k,l,:) = pihat(k,l,:) +max(0, 1/2 *ka(l-1)^2 *ww*(ww +rr*dd));

        % Volatility prediction error
        da(k,l) = (1/pi(k,l,y(k)) +(mu(k,l,y(k)) -muhat(k,l,y(k)))^2) *pihat(k,l,y(k)) -1;

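        % Coupled updating: with two bandits whose outcome probabilities sum to
        % one, the unchosen bandit's representations are filled in from the chosen
        % bandit's (1 - mu at the first level, the matching logit at the second).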
        if coupled == true
            if y(k) == 1
                mu(k,1,2) = 1 -mu(k,1,1);
                mu(k,2,2) = tapas_logit(1 -tapas_sgm(mu(k,2,1), 1), 1);
            elseif y(k) == 2
                mu(k,1,1) = 1 -mu(k,1,2);
                mu(k,2,1) = tapas_logit(1 -tapas_sgm(mu(k,2,2), 1), 1);
            end
        end
    else
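        % Ignored trial (k-1 listed in r.ign): carry the previous representations forward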

        mu(k,:,:) = mu(k-1,:,:);
        pi(k,:,:) = pi(k-1,:,:);

        muhat(k,:,:) = muhat(k-1,:,:);
        pihat(k,:,:) = pihat(k-1,:,:);

        v(k,:) = v(k-1,:);
        w(k,:) = w(k-1,:);
        da(k,:) = da(k-1,:);

    end
end

% Remove representation priors
mu(1,:,:) = [];
pi(1,:,:) = [];

% Remove other dummy initial values
muhat(1,:,:) = [];
pihat(1,:,:) = [];
v(1,:) = [];
w(1,:) = [];
da(1,:) = [];
y(1) = [];

% Responses on regular trials
yreg = y;
yreg(irr) = [];

% Implied learning rate at the first level
mu2 = squeeze(mu(:,2,:));
mu2(irr,:) = [];
mu2obs = mu2(sub2ind(size(mu2), (1:size(mu2,1))', yreg));

mu1hat = squeeze(muhat(:,1,:));
mu1hat(irr,:) = [];
mu1hatobs = mu1hat(sub2ind(size(mu1hat), (1:size(mu1hat,1))', yreg));

upd1 = tapas_sgm(ka(1)*mu2obs,1) -mu1hatobs;

dareg = da;
dareg(irr,:) = [];

lr1reg = upd1./dareg(:,1);
lr1 = NaN(n-1,1);
lr1(setdiff(1:n-1, irr)) = lr1reg;
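
% (lr1 is the first-level update for the chosen bandit, recovered from the
% posterior at level 2 and divided by the outcome prediction error; it is NaN
% on irregular trials.)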

% Create result data structure
traj = struct;

traj.mu = mu;
traj.sa = 1./pi;

traj.muhat = muhat;
traj.sahat = 1./pihat;

traj.v = v;
traj.w = w;
traj.da = da;

% Updates with respect to prediction
traj.ud = mu -muhat;

% Psi (precision weights on prediction errors)
psi = NaN(n-1,l);

pi2 = squeeze(pi(:,2,:));
pi2(irr,:) = [];
pi2obs = pi2(sub2ind(size(pi2), (1:size(pi2,1))', yreg));

psi(setdiff(1:n-1, irr), 2) = 1./pi2obs;

for i = 3:l
    pihati = squeeze(pihat(:,i-1,:));
    pihati(irr,:) = [];
    pihatiobs = pihati(sub2ind(size(pihati), (1:size(pihati,1))', yreg));

    pii = squeeze(pi(:,i,:));
    pii(irr,:) = [];
    piiobs = pii(sub2ind(size(pii), (1:size(pii,1))', yreg));

    psi(setdiff(1:n-1, irr), i) = pihatiobs./piiobs;
end

traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi = NaN(n-1,l);
epsi(:,2:l) = psi(:,2:l) .*da(:,1:l-1);
traj.epsi = epsi;
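
% (epsi(:,j) is the prediction error from level j-1 weighted by psi(:,j): at
% level 2 the weight is the chosen bandit's posterior variance, at higher levels
% the ratio of the predicted precision below to the posterior precision.)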

% Full learning rate (full weights on prediction errors)
wt = NaN(n-1,l);
wt(:,1) = lr1;
wt(:,2) = psi(:,2);
wt(:,3:l) = 1/2 *(v(:,2:l-1) *diag(ka(2:l-1))) .*psi(:,3:l);
traj.wt = wt;
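
% (wt collects the effective weights on the prediction errors: the implied
% first-level learning rate, psi at level 2, and 1/2 * ka(j-1) * v(:,j-1)
% scaled by psi(:,j) at levels 3 and above.)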

% Create matrices for use by the observation model
infStates = NaN(n-1,l,b,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;
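
% (infStates is trials-by-levels-by-bandits-by-4, with the predicted means,
% predicted variances, posterior means, and posterior variances stacked along
% its fourth dimension.)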

end