|
function [traj, infStates] = tapas_ehgf_ar1_binary_mab(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF in a multi-armed bandit
% situation with binary outcomes
%
% This function can be called in two ways:
%
% (1) tapas_ehgf_ar1_binary_mab(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_ehgf_ar1_binary_mab(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Fields of r read by this function:
%     r.c_prc.n_levels            (optional) number of levels l; if absent, l is inferred from
%                                 length(p) = 7*l-1
%     r.c_prc.n_bandits           number of bandits b (required)
%     r.c_prc.coupled             if true (only allowed for b == 2), the unchosen bandit's first
%                                 and second level means are set to mirror the chosen bandit's
%     r.c_prc.irregular_intervals if true, the last column of r.u is used as the time axis
%     r.u                         inputs; column 1 holds the outcomes
%     r.y, r.irr                  chosen bandit per trial and irregular trials (estimation); if
%                                 r.y or r.irr is unavailable, r.u(:,2) is used as the choices
%                                 (simulation)
%     r.ign                       trials to ignore (previous values are carried forward)
%
% Layout of the native-space parameter vector p (l = number of levels):
%     p(1:l)         mu_0     p(l+1:2*l)     sa_0     p(2*l+1:3*l)  phi
%     p(3*l+1:4*l)   m        p(4*l+1:5*l)   rho      p(5*l+1:6*l-1) ka
%     p(6*l:7*l-2)   om       p(7*l-1)       log of th
%
% Outputs:
%     traj       struct with fields mu, sa, muhat, sahat, v, w, da, ud, psi, epsi, wt
%     infStates  (n-1) x l x b x 4 array holding muhat, sahat, mu, sa (in that order along the
%                fourth dimension) for use by the observation model
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2017 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.


% Transform parameters back to their native space if needed
% (terminated with a semicolon so that p is not echoed to the console on every call)
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_ehgf_ar1_binary_mab_transp(r, p);
end

% Number of levels
try
    l = r.c_prc.n_levels;
catch
    % p holds 5 blocks of length l, 2 blocks of length l-1 and one scalar: 7*l-1 entries
    l = (length(p)+1)/7;

    if l ~= floor(l)
        error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
    end
end

% Number of bandits
try
    b = r.c_prc.n_bandits;
catch
    error('tapas:hgf:NumOfBanditsConfig', 'Number of bandits has to be configured in r.c_prc.n_bandits.');
end

% Coupled updating
% This is only allowed if there are 2 bandits. We here assume that the mu1hat for the two bandits
% add to unity.
coupled = false;
if r.c_prc.coupled == true
    if b == 2
        coupled = true;
    else
        error('tapas:hgf:HgfBinaryMab:CoupledOnlyForTwo', 'Coupled updating can only be configured for 2 bandits.');
    end
end

% Unpack parameters
mu_0 = p(1:l);
sa_0 = p(l+1:2*l);
phi  = p(2*l+1:3*l);
m    = p(3*l+1:4*l);
rho  = p(4*l+1:5*l);
ka   = p(5*l+1:6*l-1);
om   = p(6*l:7*l-2);
th   = exp(p(7*l-1));


% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
try % For estimation
    y = [1; r.y(:,1)];
    irr = r.irr;
catch % For simulation
    y = [1; r.u(:,2)];
    irr = find(isnan(r.u(:,2)));
end

% Number of trials (including prior)
n = size(u,1);

% Construct time axis
if r.c_prc.irregular_intervals
    % Check the original input matrix: the local u always has exactly one column,
    % so testing size(u,2) here would make this branch fail unconditionally.
    if size(r.u,2) > 1
        t = [0; r.u(:,end)];
    else
        error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than one column if irregular_intervals is set to true.');
    end
else
    t = ones(n,1);
end

% Initialize updated quantities
% Note: the local variable pi (precisions) shadows the built-in constant pi inside this function.

% Representations
mu = NaN(n,l,b);
pi = NaN(n,l,b);

% Other quantities
muhat = NaN(n,l,b);
pihat = NaN(n,l,b);
v     = NaN(n,l);
w     = NaN(n,l-1);
da    = NaN(n,l);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu(1,1,:) = tapas_sgm(mu_0(2), 1);
muhat(1,1,:) = mu(1,1,:);
pihat(1,1,:) = 0;
pi(1,1,:) = Inf;
mu(1,2:end,:) = repmat(mu_0(2:end),[1 1 b]);
pi(1,2:end,:) = repmat(1./sa_0(2:end),[1 1 b]);

% Pass through representation update loop
for k = 2:1:n
    if not(ismember(k-1, r.ign))

        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%

        % 2nd level prediction (for every bandit; indexing with only two subscripts
        % would address bandit 1 alone and leave the other bandits at NaN)
        muhat(k,2,:) = mu(k-1,2,:) +t(k) *rho(2) +t(k) *phi(2) *(m(2) -mu(k-1,2,:));

        % 1st level
        % ~~~~~~~~~
        % Prediction
        muhat(k,1,:) = tapas_sgm(ka(1) *muhat(k,2,:), 1);

        % Precision of prediction
        pihat(k,1,:) = 1./(muhat(k,1,:).*(1 -muhat(k,1,:)));

        % Updates: only the chosen bandit y(k) observes the outcome u(k)
        pi(k,1,:) = pihat(k,1,:);
        pi(k,1,y(k)) = Inf;

        mu(k,1,:) = muhat(k,1,:);
        mu(k,1,y(k)) = u(k);

        % Prediction error
        da(k,1) = mu(k,1,y(k)) -muhat(k,1,y(k));

        % 2nd level
        % ~~~~~~~~~
        % Prediction: see above

        % Precision of prediction
        pihat(k,2,:) = 1./(1./pi(k-1,2,:) +exp(ka(2) *mu(k-1,3,:) +om(2)));

        % Updates
        pi(k,2,:) = pihat(k,2,:) +ka(1)^2./pihat(k,1,:);

        mu(k,2,:) = muhat(k,2,:);
        mu(k,2,y(k)) = muhat(k,2,y(k)) +ka(1)/pi(k,2,y(k)) *da(k,1);

        % Volatility prediction error
        da(k,2) = (1/pi(k,2,y(k)) +(mu(k,2,y(k)) -muhat(k,2,y(k)))^2) *pihat(k,2,y(k)) -1;

        if l > 3
            % Pass through higher levels
            % ~~~~~~~~~~~~~~~~~~~~~~~~~~
            for j = 3:l-1
                % Prediction (drift rho(j) included, consistent with level 2 and the last level)
                muhat(k,j,:) = mu(k-1,j,:) +t(k) *rho(j) +t(k) *phi(j) *(m(j) -mu(k-1,j,:));

                % Precision of prediction
                pihat(k,j,:) = 1./(1./pi(k-1,j,:) +t(k) *exp(ka(j) *mu(k-1,j+1,:) +om(j)));

                % Weighting factor
                v(k,j-1) = t(k) *exp(ka(j-1) *mu(k-1,j,y(k)) +om(j-1));
                w(k,j-1) = v(k,j-1) *pihat(k,j-1,y(k));

                % Mean updates (the mean is updated before the precision; the precision
                % update below depends on the updated mean)
                mu(k,j,:) = muhat(k,j,:) +1/2 *1/pihat(k,j,y(k)) *ka(j-1) *w(k,j-1) *da(k,j-1);

                % Ingredients of precision update which depend on the mean
                % update. Quantities of the level below are taken from the chosen
                % bandit y(k), as for v, w and da.
                vv = t(k) *exp(ka(j-1) *mu(k,j,y(k)) +om(j-1));
                pimhat = 1/(1/pi(k-1,j-1,y(k)) +vv);
                ww = vv *pimhat;
                rr = (vv -1/pi(k-1,j-1,y(k))) *pimhat;
                dd = (1/pi(k,j-1,y(k)) +(mu(k,j-1,y(k)) -muhat(k,j-1,y(k)))^2) *pimhat -1;

                % Precision update (the max(0, .) keeps the posterior precision
                % from falling below the predicted precision)
                pi(k,j,:) = pihat(k,j,:) +max(0, 1/2 *ka(j-1)^2 *ww*(ww +rr*dd));

                % Volatility prediction error
                da(k,j) = (1/pi(k,j,y(k)) +(mu(k,j,y(k)) -muhat(k,j,y(k)))^2) *pihat(k,j,y(k)) -1;
            end
        end

        % Last level
        % ~~~~~~~~~~
        % Prediction
        muhat(k,l,:) = mu(k-1,l,:) +t(k) *rho(l) +t(k) *phi(l) *(m(l) -mu(k-1,l,:));

        % Precision of prediction
        pihat(k,l,:) = 1./(1./pi(k-1,l,:) +t(k) *th);

        % Weighting factor
        v(k,l)   = t(k) *th;
        v(k,l-1) = t(k) *exp(ka(l-1) *mu(k-1,l,y(k)) +om(l-1));
        w(k,l-1) = v(k,l-1) *pihat(k,l-1,y(k));

        % Mean updates
        mu(k,l,:) = muhat(k,l,:) +1/2 *1/pihat(k,l,y(k)) *ka(l-1) *w(k,l-1) *da(k,l-1);

        % Ingredients of the precision update which depend on the mean
        % update. Quantities of level l-1 are taken from the chosen bandit y(k);
        % reading bandit 1 here would be wrong whenever another bandit was chosen.
        vv = t(k) *exp(ka(l-1) *mu(k,l,y(k)) +om(l-1));
        pimhat = 1/(1/pi(k-1,l-1,y(k)) +vv);
        ww = vv *pimhat;
        rr = (vv -1/pi(k-1,l-1,y(k))) *pimhat;
        dd = (1/pi(k,l-1,y(k)) +(mu(k,l-1,y(k)) -muhat(k,l-1,y(k)))^2) *pimhat -1;

        % Precision update
        pi(k,l,:) = pihat(k,l,:) +max(0, 1/2 *ka(l-1)^2 *ww*(ww +rr*dd));

        % Volatility prediction error
        da(k,l) = (1/pi(k,l,y(k)) +(mu(k,l,y(k)) -muhat(k,l,y(k)))^2) *pihat(k,l,y(k)) -1;

        % Coupled updating: mirror the chosen bandit's first and second level means
        % onto the other bandit
        if coupled == true
            if y(k) == 1
                mu(k,1,2) = 1 -mu(k,1,1);
                mu(k,2,2) = tapas_logit(1 -tapas_sgm(mu(k,2,1), 1), 1);
            elseif y(k) == 2
                mu(k,1,1) = 1 -mu(k,1,2);
                mu(k,2,1) = tapas_logit(1 -tapas_sgm(mu(k,2,2), 1), 1);
            end
        end
    else
        % Ignored trial: carry all quantities forward unchanged

        mu(k,:,:) = mu(k-1,:,:);
        pi(k,:,:) = pi(k-1,:,:);

        muhat(k,:,:) = muhat(k-1,:,:);
        pihat(k,:,:) = pihat(k-1,:,:);

        v(k,:)  = v(k-1,:);
        w(k,:)  = w(k-1,:);
        da(k,:) = da(k-1,:);

    end
end

% Remove representation priors
mu(1,:,:)  = [];
pi(1,:,:)  = [];

% Remove other dummy initial values
muhat(1,:,:) = [];
pihat(1,:,:) = [];
v(1,:)       = [];
w(1,:)       = [];
da(1,:)      = [];
y(1)         = [];

% Responses on regular trials
yreg = y;
yreg(irr) =[];

% Implied learning rate at the first level
% (values of the chosen bandit on each regular trial are picked out via sub2ind)
mu2 = squeeze(mu(:,2,:));
mu2(irr,:) = [];
mu2obs = mu2(sub2ind(size(mu2), (1:size(mu2,1))', yreg));

mu1hat = squeeze(muhat(:,1,:));
mu1hat(irr,:) = [];
mu1hatobs = mu1hat(sub2ind(size(mu1hat), (1:size(mu1hat,1))', yreg));

upd1 = tapas_sgm(ka(1)*mu2obs,1) -mu1hatobs;

dareg = da;
dareg(irr,:) = [];

lr1reg = upd1./dareg(:,1);
lr1    = NaN(n-1,1);
lr1(setdiff(1:n-1, irr)) = lr1reg;

% Create result data structure
traj = struct;

traj.mu     = mu;
traj.sa     = 1./pi;

traj.muhat  = muhat;
traj.sahat  = 1./pihat;

traj.v      = v;
traj.w      = w;
traj.da     = da;

% Updates with respect to prediction
traj.ud = mu -muhat;

% Psi (precision weights on prediction errors)
psi = NaN(n-1,l);

pi2 = squeeze(pi(:,2,:));
pi2(irr,:) = [];
pi2obs = pi2(sub2ind(size(pi2), (1:size(pi2,1))', yreg));

psi(setdiff(1:n-1, irr), 2) = 1./pi2obs;

for i=3:l
    pihati = squeeze(pihat(:,i-1,:));
    pihati(irr,:) = [];
    pihatiobs = pihati(sub2ind(size(pihati), (1:size(pihati,1))', yreg));

    pii = squeeze(pi(:,i,:));
    pii(irr,:) = [];
    piiobs = pii(sub2ind(size(pii), (1:size(pii,1))', yreg));

    psi(setdiff(1:n-1, irr), i) = pihatiobs./piiobs;
end

traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi        = NaN(n-1,l);
epsi(:,2:l) = psi(:,2:l) .*da(:,1:l-1);
traj.epsi   = epsi;

% Full learning rate (full weights on prediction errors)
wt        = NaN(n-1,l);
wt(:,1)   = lr1;
wt(:,2)   = psi(:,2);
wt(:,3:l) = 1/2 *(v(:,2:l-1) *diag(ka(2:l-1))) .*psi(:,3:l);
traj.wt   = wt;

% Create matrices for use by the observation model
infStates = NaN(n-1,l,b,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;

end
0 commit comments