Skip to content

Commit ee7a595

Browse files
added two scripts to paper directory for simulated vs. CN and MC and added some text
1 parent a04b0ac commit ee7a595

File tree

3 files changed

+156
-2
lines changed

3 files changed

+156
-2
lines changed

paper/paper.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,20 @@ bibliography: paper.bib
4444

4545
# Summary
4646

47-
Drift diffusion models (DDMs) are a popular model class for modeling a unobserved process that determines an subject's choice during a decision-making task [@Bogacz2006].
47+
Drift diffusion models (DDMs) are a popular model class for modeling a unobserved process that determines an subject's choice during a decision-making task [@Bogacz2006]. Mathematically, in their simplest form, they are equivalent of an [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) process, a type of mean-reverting stochastic process similar to Brownian motion, described by the following stochastic differential equation
4848

4949
```math
50-
dz = \lambda zdt + u(t)dt + \sigma dW
50+
dz = \lambda zdt + u(t)dt + \sigma dW \tag{1}
5151
```
5252

53+
These dynamics can be equivalently expressed as a partial differential equation, which describes the motion of the probability distribution of $z$
54+
55+
```math
56+
\frac{\partial P(z(t))}{\partial t} = \frac{\sigma}{2}\frac{\partial^2 P}{\partial z^2} - \frac{\partial(\lambda zP)}{\partial z} - \frac{\partial(u(t)P)}{\partial z}. \tag{2}
57+
```
58+
59+
In neuroscience, as mass of this distribution moves, this can be considered to be a simple model of the internal process by which evidence is accumlated and weighed between options. In cases where evidence is received continuously in time, an external input $u(t)$ is included. This external input forces our PDE to be solved numerically.
60+
5361
[@Brunton2013].
5462

5563
# Statement of need

paper/test.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using LinearAlgebra, PyPlot
2+
3+
# Parameters
4+
dt = 1e-2
5+
B = 10.
6+
λ = -4.
7+
σ2 = 10.
8+
n = 53
9+
10+
xc, dx = PulseInputDDM.bins(B, n);
11+
P0 = PulseInputDDM.P0(dx^2, n, dx, xc, dt)
12+
M = PulseInputDDM.transition_M(σ2*dt, λ, 0., dx, xc, n, dt);
13+
14+
# Calculate coefficients
15+
cDiff = σ2 * dt / (2*dx^2)
16+
cDrft = λ * dt / (2 * dx)
17+
Ddff = diagm(0 => ones(n) * -2cDiff, 1 => ones(n-1) * cDiff, -1 => ones(n-1) * cDiff)
18+
Ddff[:, [1, n]] .= 0
19+
Dder = diagm(0 => zeros(n), 1 => ones(n-1) * cDrft, -1 => ones(n-1) * -cDrft)
20+
Dder[:, [1, n]] .= 0
21+
C = exp(-Dder * diagm(0 => xc) + Ddff)
22+
23+
#plot(xc, P0, label="P0")
24+
#plot(xc, M^(1/dt) * P0, label="M")
25+
#plot(xc, C^(1/dt) * P0, label="C")
26+
#legend()
27+
28+
plot(sort(real(eigvals(M))), label="realM")
29+
plot(sort(abs.(imag(eigvals(M)))), label="imagM")
30+
plot(sort(real(eigvals(C))), label="realC")
31+
plot(sort(abs.(imag(eigvals(C)))), label="imagC")
32+
legend()

paper/test2.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
using LinearAlgebra, SparseArrays, Random, Distributions, Plots, StatsBase
2+
3+
# Clear all variables
4+
T = 1.0 # time of simulation
5+
dt = 2e-2 # Fokker-Planck approximations timestep
6+
nMC = 1e4 # number of Monte Carlo particles
7+
dtMC = 1e-4 # Monte Carlo timestep
8+
9+
B = 1.0 # bound height
10+
vara = 1e-2 # diffusion variance
11+
lambda = -1e-3 # drift
12+
mu0 = 0.0 * B # initial mean of distribution
13+
n = 53 # number of bins
14+
15+
dx = 2 * B / (n - 2) # bin width
16+
dx2 = dx^2 # dx^2
17+
vari = dx2 # make initial distribution as large as dx^2
18+
19+
# Finite difference
20+
21+
xe = vcat(-(B+dx), range(-B, -dx/2, ceil(Int, (n-3)/2+1))...,
22+
range(dx/2, B, floor(Int, (n-3)/2+1))..., B+dx) # edges of the bins
23+
xc = (xe[1:n] .+ dx/2)' # bin centers
24+
25+
cDiff = vara / (dx^2 * 2) # scale factor for diffusion
26+
cDrft = lambda / (dx * 2) # scale factor for drift
27+
28+
# Diffusion matrix
29+
Ddff = diagm(0 => -2 * cDiff * [0; ones(n-2); 0]) +
30+
diagm(-1 => cDiff * [0; ones(n-2)]) +
31+
diagm(1 => cDiff * [ones(n-2); 0])
32+
33+
# Drift matrix
34+
Dder = diagm(-1 => -cDrft * [0; ones(n-2)]) +
35+
diagm(1 => cDrft * [ones(n-2); 0])
36+
37+
# Multiply by a values and add
38+
DD = -Dder * diagm(0 => xc') + Ddff
39+
40+
# Matrix exponential
41+
M = exp(DD * dt)
42+
43+
# Brunton method
44+
45+
# You would replace this with your implementation of `make_F` in Julia
46+
M_B = PulseInputDDM.transition_M(vara*dt, lambda, 0., dx, xc', n, dt);
47+
48+
# Initialize
49+
50+
Pa = pdf.(Normal(mu0, sqrt(vari)), xc) .* dx
51+
Pa /= sum(Pa) # Fin. Diff.
52+
Pa = collect(Pa')
53+
Pa_B = copy(Pa) # Brunton
54+
55+
# For saving
56+
PA = zeros(n, round(Int, T/dt))
57+
PA_B = similar(PA)
58+
59+
# Propagate
60+
61+
for t in 1:round(Int, T/dt)
62+
global Pa, Pa_B # Ensure these refer to the global variables
63+
PA[:, t] = Pa
64+
PA_B[:, t] = Pa_B
65+
66+
Pa = M * Pa # Fin. Diff.
67+
Pa_B = M_B * Pa_B # Brunton (Uncomment this once you define `M_B`)
68+
69+
end
70+
71+
# Monte Carlo
72+
73+
# For saving
74+
PMC = fill(NaN, size(PA))
75+
76+
a = mu0 .+ sqrt(vari) .* randn(Int(nMC)) # Initialize Gaussian
77+
78+
for t in 1:round(Int, T/dtMC)
79+
80+
global a
81+
82+
a[a .< -(B+dx)] .= -(B+dx/2)
83+
a[a .> B+dx] .= B+dx/2
84+
85+
# Bin the individual particles
86+
if mod(t-1, round(Int, dt/dtMC)) + 1 == 1
87+
# Bin the particles and normalize by the number of Monte Carlo particles
88+
counts = StatsBase.fit(Histogram, a, xe).weights
89+
PMC[:, ceil(Int, dtMC*(t)/dt)] = (1/nMC) * counts
90+
#PMC[:, ceil(Int, dtMC*(t)/dt)] = (1/nMC) * histcounts(a, xe)
91+
end
92+
93+
go = (a .< B) .& (a .> -B) # Only integrate those that haven't crossed the boundary
94+
95+
# OU process: 2 terms, 1-drift and 2-diffusion
96+
a[go] .= a[go] .+ (lambda * dtMC) .* a[go] .+ sqrt(vara * dtMC) .* randn(sum(go))
97+
end
98+
99+
# Plot over time
100+
101+
p = plot()
102+
plot!(p, xc', PA[:, 1], label="Fin.Diff.", color="red", lw=2, marker=:o)
103+
plot!(p, xc', PA_B[:, 1], label="Brunton", color="green", lw=2, marker=:star)
104+
plot!(p, xc', PMC[:, 1], label="Monte Carlo", color="blue", lw=2, marker=:star)
105+
display(p)
106+
107+
for i in 1:size(PA, 2)
108+
p = plot()
109+
plot!(p, xc', PA[:, i], label="Fin.Diff.")
110+
plot!(p, xc', PA_B[:, i], label="Brunton")
111+
plot!(p, xc', PMC[:, i], label="Monte Carlo")
112+
sleep(0.1)
113+
display(p)
114+
end

0 commit comments

Comments
 (0)