Skip to content

Commit

Permalink
implicit diff improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Aug 23, 2022
1 parent b16b633 commit 56373da
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/muse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,12 @@ function get_H!(

if implicit_diff

# check we can do implicit diff on this problem
(x, z) = sample_x_z(prob, copy(rng), θ₀)
if !(eltype(x) <:AbstractFloat && eltype(z) <:AbstractFloat)
error("implicit_diff=true requires elements of `x` and `z` to be `AbstractFloat`s (ie they must be continuous numbers).")
end

pbar = progress ? RemoteProgress(nsims_remaining, 0.1, "get_H: ") : nothing

append!(result.Hs, skipmissing(pmap(pool, rngs) do rng
Expand All @@ -318,11 +324,11 @@ function get_H!(
ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff)

# non-implicit-diff term
H1 = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
first(AD.gradient(θ₀, backend=ad_fwd) do θ
H1 = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
first(AD.gradient(θ₀, backend=ad_fwd) do θ
logLike(prob, sample_x_z(prob, copy(rng), θ).x, ẑ, θ′, UnTransformedθ())
end)
end)
end)'

# term involving dzMAP/dθ via implicit-diff (w/ conjugate-gradient linear solve)
dFdθ = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
Expand All @@ -343,7 +349,7 @@ function get_H!(
end)
end)
end
H2 = -(dFdθ' * cg(A, dFdθ1))
H2 = -(dFdθ' * mapreduce(w -> cg(A, w), hcat, eachcol(dFdθ1)))

H = H1 + H2
progress && ProgressMeter.next!(pbar)
Expand Down

0 comments on commit 56373da

Please sign in to comment.