Skip to content

Commit

Permalink
Type stability
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jul 13, 2023
1 parent fba4498 commit 6771d15
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 26 deletions.
71 changes: 47 additions & 24 deletions src/datafit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair})
(pdist, IndexKeyMap(prob, pkeys))
end

Turing.@model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
Turing.@model function bayesianODE(prob, alg, t, pdist, pkeys, data, datamap, noise_prior)
σ ~ noise_prior

pprior ~ product_distribution(pdist)

prob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
sol = solve(prob, saveat = t)
sol = solve(prob, alg, saveat = t)
if !SciMLBase.successful_retcode(sol)
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
return nothing
end
for i in eachindex(data)
data[i].second ~ MvNormal(sol[data[i].first], σ^2 * I)
data[i] ~ MvNormal(datamap(sol), σ^2 * I)
end
return nothing
end

Turing.@model function bayesianODE(prob,
Turing.@model function bayesianODE(prob, alg,
pdist,
pkeys,
ts,
Expand All @@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob,
pprior ~ product_distribution(pdist)

prob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior)
sol = solve(prob)
sol = solve(prob, alg)
if !SciMLBase.successful_retcode(sol)
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
return nothing
Expand Down Expand Up @@ -264,18 +264,19 @@ end
Base.length(ws::WeightedSol) = length(first(ws.sols))
Base.size(ws::WeightedSol) = (length(first(ws.sols)),)
function Base.getindex(ws::WeightedSol{T}, i::Int) where {T}
s = zero(T)
w = zero(T)
for j in eachindex(ws.weights)
s::T = zero(T)
w::T = zero(T)
@inbounds for j in eachindex(ws.weights)
w += ws.weights[j]
s += ws.weights[j] * ws.sols[j][i]
end
return s + (one(T) - w) * ws.sols[end][i]
end
function WeightedSol(sols, select, weights)
T = eltype(weights)
s = map(Base.Fix2(getindex, select), sols)
WeightedSol{T}(s, weights)
function WeightedSol(sols, select, i::Int, weights)
s = map(sols, select) do sol, sel
@view(sol[sel.indices[i], :])
end
WeightedSol{eltype(weights)}(s, weights)
end
function bayes_unpack_data(probs, p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data)
pdist, pkeys = bayes_unpack_data(probs, p)
Expand Down Expand Up @@ -305,43 +306,46 @@ function flatten(x::Tuple)
reduce(vcat, x), Grouper(map(length, x))
end

function getsols(probs, probspkeys, ppriors, t::AbstractArray)
map(probs, probspkeys, ppriors) do prob, pkeys, pprior
function getsols(probs, algs, probspkeys, ppriors, t::AbstractArray)
map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior
newprob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
solve(newprob, saveat = t)
solve(newprob, alg, saveat = t)
end
end
function getsols(probs, probspkeys, ppriors, lastt::Number)
map(probs, probspkeys, ppriors) do prob, pkeys, pprior
function getsols(probs, algs, probspkeys, ppriors, lastt::Number)
map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior
newprob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior)
solve(newprob)
solve(newprob, alg)
end
end

Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
algs,
t,
pdist,
grouppriorsfunc,
probspkeys,
data,
datamaps,
noise_prior)
σ ~ noise_prior
ppriors ~ product_distribution(pdist)

Nprobs = length(probs)
Nprobs⁻¹ = inv(Nprobs)
weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹)
sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), t)
sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), t)
if !all(SciMLBase.successful_retcode, sols)
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
return nothing
end
for i in eachindex(data)
data[i].second ~ MvNormal(WeightedSol(sols, data[i].first, weights), σ^2 * I)
data[i] ~ MvNormal(WeightedSol(sols, datamaps, i, weights), σ^2 * I)
end
return nothing
end
Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
algs,
pdist,
grouppriorsfunc,
probspkeys,
Expand All @@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
σ ~ noise_prior
ppriors ~ product_distribution(pdist)

sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), lastt)
sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), lastt)

Nprobs = length(probs)
Nprobs⁻¹ = inv(Nprobs)
Expand Down Expand Up @@ -411,7 +415,14 @@ function bayesian_datafit(prob,
nchains = 4,
niter = 1000)
(pdist, pkeys) = bayes_unpack_data(prob, p)
model = bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
model = bayesianODE(prob,
first(default_algorithm(prob)),
t,
pdist,
pkeys,
last.(data),
IndexKeyMap(prob, data),
noise_prior)
chain = Turing.sample(model,
Turing.NUTS(0.65),
mcmcensemble,
Expand All @@ -430,7 +441,15 @@ function bayesian_datafit(prob,
nchains = 4,
niter = 1_000)
pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(prob, p, data)
model = bayesianODE(prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior)
model = bayesianODE(prob,
first(default_algorithm(prob)),
pdist,
pkeys,
ts,
lastt,
timeseries,
datakeys,
noise_prior)
chain = Turing.sample(model,
Turing.NUTS(0.65),
mcmcensemble,
Expand All @@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
(pdist_, pkeys) = bayes_unpack_data(p)
pdist, grouppriorsfunc = flatten(pdist_)

model = ensemblebayesianODE(probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior)
model = ensemblebayesianODE(probs,
map(first default_algorithm, probs),
t, pdist, grouppriorsfunc, pkeys, last.(data),
map(Base.Fix2(IndexKeyMap, data), probs), noise_prior)
chain = Turing.sample(model,
Turing.NUTS(0.65),
mcmcensemble,
Expand All @@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data)
pdist, grouppriorsfunc = flatten(pdist_)
model = ensemblebayesianODE(probs,
map(first default_algorithm, probs),
pdist,
grouppriorsfunc,
pkeys,
Expand Down
18 changes: 16 additions & 2 deletions src/keyindexmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ struct IndexKeyMap
indices::Vector{Int}
end

# probs support
function IndexKeyMap(prob, keys)
params = ModelingToolkit.parameters(prob.f.sys)
indices = Vector{Int}(undef, length(keys))
Expand All @@ -12,7 +13,8 @@ function IndexKeyMap(prob, keys)
return IndexKeyMap(indices)
end

Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector)
Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob::SciMLBase.AbstractDEProblem,
v::AbstractVector)
@boundscheck checkbounds(v, length(ikm.indices))
def = prob.p
ret = Vector{Base.promote_eltype(v, def)}(undef, length(def))
Expand All @@ -22,8 +24,20 @@ Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector)
end
return ret
end

function _remake(prob, tspan, ikm::IndexKeyMap, pprior)
p = ikm(prob, pprior)
remake(prob; tspan, p)
end

# data support
function IndexKeyMap(prob, data::AbstractVector{<:Pair})
states = ModelingToolkit.states(prob.f.sys)
indices = Vector{Int}(undef, length(data))
for i in eachindex(data)
indices[i] = findfirst(Base.Fix1(isequal, data[i].first), states)
end
return IndexKeyMap(indices)
end
function (ikm::IndexKeyMap)(sol::SciMLBase.AbstractTimeseriesSolution)
(@view(sol[i, :]) for i in ikm.indices)
end

0 comments on commit 6771d15

Please sign in to comment.