From 4e66c412c7a14fa7a3b0a8990442186c81a2d17d Mon Sep 17 00:00:00 2001 From: marius Date: Sat, 26 Aug 2023 19:23:58 -0700 Subject: [PATCH] fix incorrect implicit diff H term --- src/muse.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/muse.jl b/src/muse.jl index 6ab49ff..4632fca 100644 --- a/src/muse.jl +++ b/src/muse.jl @@ -352,12 +352,12 @@ function get_H!( ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff) ## non-implicit-diff term - H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), permutedims(first(AD.jacobian(θ₀, backend=ad_fwd) do θ + H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), first(AD.jacobian(θ₀, backend=ad_fwd) do θ local x, = sample_x_z(prob, copy(rng), θ) first(AD.gradient(θ₀, backend=ad_rev) do θ′ logLike(prob, x, ẑ, θ′, UnTransformedθ()) end) - end))) + end)) ## term involving dzMAP/dθ via implicit-diff (w/ conjugate-gradient linear solve) dFdθ = first(AD.jacobian(θ₀, backend=ad_fwd) do θ