From 5443ae32a19f7b909c70b83f6c069ea43c881d01 Mon Sep 17 00:00:00 2001 From: marius Date: Tue, 23 Aug 2022 00:47:26 -0500 Subject: [PATCH] fix turing for matrix latent --- src/turing.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/turing.jl b/src/turing.jl index 2872f5d..9cdd03f 100644 --- a/src/turing.jl +++ b/src/turing.jl @@ -121,6 +121,8 @@ function TuringMuseProblem( error("Unsupposed backend from Turing: $(Turing.ADBACKEND)") end end + # ensure tuple + params = (params...,) # prevent this constructor from advancing the default RNG for more clear reproducibility rng = copy(Random.default_rng()) # model is expected to be passed in conditioned on x @@ -160,14 +162,14 @@ end function transform_θ(prob::TuringMuseProblem, θ) vi = deepcopy(prob.vi_θ) - DynPPL.setval!(vi, θ) + DynPPL.setval!(vi, (;θ...)) DynPPL.link!(vi, DynPPL.SampleFromPrior()) ComponentVector(vi) end function inv_transform_θ(prob::TuringMuseProblem, θ) vi = deepcopy(prob.vi_θ) - DynPPL.setval!(vi, θ) + DynPPL.setval!(vi, (;θ...)) for k in keys(θ) DynPPL.settrans!!(vi, true, _VarName(k)) end @@ -211,16 +213,9 @@ end -# helped to extract parameters from a sampled model. feels like there -# should be a less hacky way to do this... function _namedtuple(vi::DynPPL.VarInfo) - map(DynPPL.TypedVarInfo(vi).metadata) do m - if m.vns[1] isa DynPPL.VarName{<:Any,Setfield.IdentityLens} && length(m.vals)==1 - m.vals[1] - else - m.vals - end - end + # `values_as` seems to return Real arays so narrow eltype + map(x -> identity.(x), DynPPL.values_as(vi, NamedTuple)) end ComponentVector(vi::DynPPL.VarInfo) = ComponentVector(_namedtuple(vi))