Skip to content

Commit

Permalink
do H MAPs on worker
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Jul 13, 2023
1 parent 5d1fa85 commit 2391d36
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/muse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,13 @@ function get_H!(

try

(x, z) = sample_x_z(prob, copy(rng), θ₀)
z_start = @something(z₀, ẑ_guess_from_truth(prob, x, z, θ₀))
ẑ, = ẑ_at_θ(prob, x, z_start, θ₀, ∇z_logLike_atol=1e-1)
pbar == nothing || ProgressMeter.next!(pbar)
(x, z, z_start, ẑ) = remotecall_fetch(pool_jac) do
(x, z) = sample_x_z(prob, copy(rng), θ₀)
z_start = @something(z₀, ẑ_guess_from_truth(prob, x, z, θ₀))
ẑ, = ẑ_at_θ(prob, x, z_start, θ₀, ∇z_logLike_atol=1e-1)
pbar == nothing || ProgressMeter.next!(pbar)
(x, z, z_start, ẑ)
end
T = eltype(z_start)

ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff)
Expand Down
1 change: 1 addition & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ AD.jacobian(f, args...; backend::AD.AbstractBackend) = AD.jacobian(backend, f, a
# worker pool which just falls back to map
struct LocalWorkerPool <: AbstractWorkerPool end
Distributed.pmap(f, ::LocalWorkerPool, args...) = map(f, args...)
Distributed.remotecall_fetch(f, ::LocalWorkerPool, args...) = f(args...)

# worker pool which is equivalent to passing batch_size to pmap
struct BatchWorkerPool <: AbstractWorkerPool
Expand Down

0 comments on commit 2391d36

Please sign in to comment.