@@ -125,6 +125,10 @@ function TuringMuseProblem(
125
125
error (" Unsupposed backend from Turing: $(Turing. ADBACKEND) " )
126
126
end
127
127
end
128
+ if (Threads. nthreads () > 1 ) && hasmethod (AD. ZygoteBackend,Tuple{}) && (autodiff isa typeof (AD. ZygoteBackend ()))
129
+ error (" Turing doesn't support using the Zygote backend when Threads.nthreads()>1. Use a different backend or a single-thread." )
130
+ end
131
+
128
132
# ensure tuple
129
133
params = (params... ,)
130
134
# prevent this constructor from advancing the default RNG for more clear reproducibility
@@ -187,13 +191,13 @@ standardizeθ(prob::TuringMuseProblem, θ::Number) =
187
191
188
192
function logLike (prob:: TuringMuseProblem , x, z, θ, θ_space)
189
193
trans = is_transformed (θ_space) ? prob. trans_z′_θ′ : prob. trans_z′_θ
190
- vi = DynPPL. SimpleVarInfo ((;x... , z... , θ... ), 0 , trans)
194
+ vi = DynPPL. SimpleVarInfo ((;x... , z... , θ... ), trans)
191
195
DynPPL. logjoint (prob. model, vi)
192
196
end
193
197
194
198
function logPriorθ (prob:: TuringMuseProblem , θ, θ_space)
195
199
trans = is_transformed (θ_space) ? prob. trans_z′_θ′ : prob. trans_z′_θ
196
- vi = DynPPL. SimpleVarInfo ((;θ... ), 0 , trans)
200
+ vi = DynPPL. SimpleVarInfo ((;θ... ), trans)
197
201
DynPPL. logprior (prob. model_for_prior, vi)
198
202
end
199
203
@@ -210,13 +214,23 @@ end
210
214
211
215
function sample_x_z (prob:: TuringMuseProblem , rng:: AbstractRNG , θ)
212
216
model = DynPPL. condition (prob. model, θ)
213
- vi = DynPPL. SimpleVarInfo ((;), 0 , prob. trans_z′_θ)
217
+ vi = DynPPL. SimpleVarInfo ((;θ ... ) , prob. trans_z′_θ)
214
218
vars = DynPPL. values_as (last (DynPPL. evaluate!! (model, rng, vi)), NamedTuple)
215
219
(x = ComponentVector (select (vars, prob. observed_vars)), z = ComponentVector (select (vars, prob. latent_vars)))
216
220
end
217
221
218
222
219
223
224
+ # benevolent type-piracy:
225
+ function DynPPL. SimpleVarInfo (nt:: NamedTuple , trans:: DynPPL.AbstractTransformation )
226
+ if isempty (nt)
227
+ T = DynPPL. SIMPLEVARINFO_DEFAULT_ELTYPE
228
+ else
229
+ T = DynPPL. float_type_with_fallback (DynPPL. infer_nested_eltype (typeof (nt)))
230
+ end
231
+ DynPPL. SimpleVarInfo (nt, zero (T), trans)
232
+ end
233
+
220
234
function _namedtuple (vi:: DynPPL.VarInfo )
221
235
# `values_as` seems to return Real arays so narrow eltype
222
236
map (x -> identity .(x), DynPPL. values_as (vi, NamedTuple))
0 commit comments