Skip to content

Commit

Permalink
update transition
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye committed Apr 23, 2024
1 parent 1f6cc1c commit 91ff50b
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 132 deletions.
5 changes: 1 addition & 4 deletions src/CriticalTransitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ using Printf, DrWatson, Dates, Statistics
include("utils.jl")
include("CoupledSDEs.jl")
include("io.jl")
# include("systemanalysis/stability.jl")
# include("systemanalysis/basinsofattraction.jl")
# include("systemanalysis/basinboundary.jl")
include("trajectories/simulation.jl")
# include("trajectories/transition.jl")
include("trajectories/transition.jl")
# include("largedeviations/action.jl")
# include("largedeviations/min_action_method.jl")
# include("largedeviations/geometric_min_action_method.jl")
Expand Down
247 changes: 119 additions & 128 deletions src/trajectories/transition.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,44 @@
"""
transition(sys::StochSystem, x_i::State, x_f::State; kwargs...)
Generates a sample transition from point `x_i` to point `x_f`.
This function simulates `sys` in time, starting from initial condition `x_i`, until entering a `length(sys.u)`-dimensional ball of radius `rad_f` around `x_f`.
## Keyword arguments
* `rad_i=0.1`: radius of ball around `x_i`
* `rad_f=0.1`: radius of ball around `x_f`
* `cut_start=true`: if `false`, returns the whole trajectory up to the transition
* `dt=0.01`: time step of integration
* `tmax=1e3`: maximum time when the simulation stops even `x_f` has not been reached
* `rad_dims=1:length(sys.u)`: the directions in phase space to consider when calculating the radii
`rad_i` and `rad_f`. Defaults to all directions. To consider only a subspace of state space,
insert a vector of indices of the dimensions to be included.
* `solver=EM()`: numerical solver. Defaults to Euler-Mayurama.
* `progress`: shows a progress bar with respect to `tmax`
## Output
`[path, times, success]`
* `path` (Matrix): transition path (size [dim × N], where N is the number of time points)
* `times` (Vector): time values (since start of simulation) of the path points (size N)
* `success` (bool): if `true`, a transition occured (i.e. the ball around `x_f` has been reached), else `false`
* `kwargs...`: keyword arguments passed to [`simulate`](@ref)
See also [`transitions`](@ref), [`simulate`](@ref).
"""
function transition(sys::StochSystem, x_i::State, x_f::State;
# """
# transition(sys::StochSystem, x_i::State, x_f::State; kwargs...)
# Generates a sample transition from point `x_i` to point `x_f`.

# This function simulates `sys` in time, starting from initial condition `x_i`, until entering a `length(sys.u)`-dimensional ball of radius `rad_f` around `x_f`.

# ## Keyword arguments
# * `rad_i=0.1`: radius of ball around `x_i`
# * `rad_f=0.1`: radius of ball around `x_f`
# * `cut_start=true`: if `false`, returns the whole trajectory up to the transition
# * `dt=0.01`: time step of integration
# * `tmax=1e3`: maximum time when the simulation stops even `x_f` has not been reached
# * `rad_dims=1:length(sys.u)`: the directions in phase space to consider when calculating the radii
# `rad_i` and `rad_f`. Defaults to all directions. To consider only a subspace of state space,
# insert a vector of indices of the dimensions to be included.
# * `solver=EM()`: numerical solver. Defaults to Euler-Mayurama.
# * `progress`: shows a progress bar with respect to `tmax`

# ## Output
# `[path, times, success]`
# * `path` (Matrix): transition path (size [dim × N], where N is the number of time points)
# * `times` (Vector): time values (since start of simulation) of the path points (size N)
# * `success` (bool): if `true`, a transition occured (i.e. the ball around `x_f` has been reached), else `false`
# * `kwargs...`: keyword arguments passed to [`simulate`](@ref)

# See also [`transitions`](@ref), [`simulate`](@ref).
# """
function transition(sys::CoupledSDEs, x_i, x_f;
rad_i=0.1,
rad_f=0.1,
dt=0.01,
tmax=1e3,
solver=EM(),
showprogress=true,
cut_start=true,
rad_dims=1:length(sys.u),
cut_start=false,
rad_dims=1:length(current_state(sys)),
kwargs...)

condition(u,t,integrator) = subnorm(u - x_f; directions=rad_dims) < rad_f
affect!(integrator) = terminate!(integrator)
cb_ball = DiscreteCallback(condition, affect!)

sim = simulate(sys, x_i;
dt=dt, tmax=tmax, solver=solver,
callback=cb_ball, progress=showprogress,
kwargs...)

success = true
if sim.t[end] == tmax
success = false
end
sim = simulate(sys, tmax, x_i; callback=cb_ball, kwargs...)
success = sim.retcode == SciMLBase.ReturnCode.Terminated

simt = sim.t
if cut_start
Expand All @@ -57,6 +47,7 @@ function transition(sys::StochSystem, x_i::State, x_f::State;
while dist > rad_i
idx -= 1
dist = norm(sim[:,idx] - x_i)
idx < 1 && error("Trajactory never left the initial state sphere. Increase tmax or decrease rad_i.")
end
sim = sim[:,idx:end]
simt = simt[idx:end]
Expand Down Expand Up @@ -103,89 +94,89 @@ The `savefile` keyword argument allows saving the data to a `.jld2` or `.h5` fil
> An example script using `transitions` is available [here](https://github.com/juliadynamics/CriticalTransitions.jl/blob/main/scripts/sample_transitions_h5.jl).
"""
function transitions(sys::StochSystem, x_i::State, x_f::State, N=1;
rad_i=0.1,
rad_f=0.1,
dt=0.01,
tmax=1e3,
Nmax=1000,
solver=EM(),
cut_start=true,
rad_dims=1:length(sys.u),
savefile=nothing,
output_level=1,
showprogress::Bool=false,
kwargs...)
"""
Generates N transition samples of sys from x_i to x_f.
Supports multi-threading.
rad_i: ball radius around x_i
rad_f: ball radius around x_f
cut_start: if false, saves the whole trajectory up to the transition
savefile: if not nothing, saves data to a specified open .jld2 file
"""

samples, times, idx::Vector{Int64}, r_idx::Vector{Int64} = [], [], [], []

iterator = showprogress ? tqdm(1:Nmax) : 1:Nmax

Threads.@threads for j iterator

sim, simt, success = transition(sys, x_i, x_f;
rad_i=rad_i, rad_f=rad_f, rad_dims=rad_dims, dt=dt, tmax=tmax,
solver=solver, progress=false, cut_start=cut_start, kwargs...)

if success

if showprogress
print("\rStatus: $(length(idx)+1)/$(N) transitions complete.")
end

if savefile == nothing
push!(samples, sim);
push!(times, simt);
else # store or save in .jld2/.h5 file
write(savefile, "paths/path "*string(j), sim)
write(savefile, "times/times "*string(j), simt)
end

push!(idx, j)

if length(idx) > max(1, N - Threads.nthreads())
break
else
continue
end
else
push!(r_idx, j)
end
end

(output_level == 1) ? (return samples, times, idx, length(r_idx)) : nothing

if output_level == 2
success_rate = length(idx)/(length(r_idx)+length(idx))
mean_res_time = sum([times[i][1] for i in 1:length(times)]) + tmax*length(r_idx)
mean_trans_time = mean([(times[i][end]-times[i][1]) for i in 1:length(times)])

return TransitionPathEnsemble(samples, times, success_rate, mean_res_time, mean_trans_time)
end
end;

struct TransitionPathEnsemble
paths::Vector
times::Vector
success_rate::Real
t_res::Real
t_trans::Real
end;

function prettyprint(tpe::TransitionPathEnsemble)
"Transition path ensemble of $(length(tpe.times)) samples
- sampling success rate: $(round(tpe.success_rate, digits=3))
- mean residence time: $(round(tpe.t_res, digits=3))
- mean transition time: $(round(tpe.t_trans, digits=3))
- normalized transition rate: $(round(tpe.t_res/tpe.t_trans, digits=1))"
end

Base.show(io::IO, tpe::TransitionPathEnsemble) = print(io, prettyprint(tpe))
# function transitions(sys::StochSystem, x_i::State, x_f::State, N=1;
# rad_i=0.1,
# rad_f=0.1,
# dt=0.01,
# tmax=1e3,
# Nmax=1000,
# solver=EM(),
# cut_start=true,
# rad_dims=1:length(sys.u),
# savefile=nothing,
# output_level=1,
# showprogress::Bool=false,
# kwargs...)
# """
# Generates N transition samples of sys from x_i to x_f.
# Supports multi-threading.
# rad_i: ball radius around x_i
# rad_f: ball radius around x_f
# cut_start: if false, saves the whole trajectory up to the transition
# savefile: if not nothing, saves data to a specified open .jld2 file
# """

# samples, times, idx::Vector{Int64}, r_idx::Vector{Int64} = [], [], [], []

# iterator = showprogress ? tqdm(1:Nmax) : 1:Nmax

# Threads.@threads for j ∈ iterator

# sim, simt, success = transition(sys, x_i, x_f;
# rad_i=rad_i, rad_f=rad_f, rad_dims=rad_dims, dt=dt, tmax=tmax,
# solver=solver, progress=false, cut_start=cut_start, kwargs...)

# if success

# if showprogress
# print("\rStatus: $(length(idx)+1)/$(N) transitions complete.")
# end

# if savefile == nothing
# push!(samples, sim);
# push!(times, simt);
# else # store or save in .jld2/.h5 file
# write(savefile, "paths/path "*string(j), sim)
# write(savefile, "times/times "*string(j), simt)
# end

# push!(idx, j)

# if length(idx) > max(1, N - Threads.nthreads())
# break
# else
# continue
# end
# else
# push!(r_idx, j)
# end
# end

# (output_level == 1) ? (return samples, times, idx, length(r_idx)) : nothing

# if output_level == 2
# success_rate = length(idx)/(length(r_idx)+length(idx))
# mean_res_time = sum([times[i][1] for i in 1:length(times)]) + tmax*length(r_idx)
# mean_trans_time = mean([(times[i][end]-times[i][1]) for i in 1:length(times)])

# return TransitionPathEnsemble(samples, times, success_rate, mean_res_time, mean_trans_time)
# end
# end;

# struct TransitionPathEnsemble
# paths::Vector
# times::Vector
# success_rate::Real
# t_res::Real
# t_trans::Real
# end;

# function prettyprint(tpe::TransitionPathEnsemble)
# "Transition path ensemble of $(length(tpe.times)) samples
# - sampling success rate: $(round(tpe.success_rate, digits=3))
# - mean residence time: $(round(tpe.t_res, digits=3))
# - mean transition time: $(round(tpe.t_trans, digits=3))
# - normalized transition rate: $(round(tpe.t_res/tpe.t_trans, digits=1))"
# end

# Base.show(io::IO, tpe::TransitionPathEnsemble) = print(io, prettyprint(tpe))
18 changes: 18 additions & 0 deletions test/trajactories/transition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

@testset "fitzhugh_nagumo" begin
p = [1.0, 3.0, 1.0, 1.0, 1.0, 0.0] # Parameters (ϵ, β, α, γ, κ, I)
σ = 0.24 # noise strength

# StochSystem
sys = CoupledSDEs(fitzhugh_nagumo, diag_noise_funtion(σ), zeros(2), p)

# Calculate fixed points
ds = CoupledODEs(sys)
box = intervals_to_box([-2, -2], [2, 2])
eqs, eigs, stab = fixedpoints(ds, box)

# Store the two stable fixed points
fp1, fp2 = eqs[stab]

transition(sys, fp1, fp2)
end

0 comments on commit 91ff50b

Please sign in to comment.