Skip to content

Commit

Permalink
fix turing for matrix latent
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Aug 23, 2022
1 parent 56373da commit 5443ae3
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions src/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ function TuringMuseProblem(
error("Unsupposed backend from Turing: $(Turing.ADBACKEND)")
end
end
# ensure tuple
params = (params...,)
# prevent this constructor from advancing the default RNG for more clear reproducibility
rng = copy(Random.default_rng())
# model is expected to be passed in conditioned on x
Expand Down Expand Up @@ -160,14 +162,14 @@ end

function transform_θ(prob::TuringMuseProblem, θ)
vi = deepcopy(prob.vi_θ)
DynPPL.setval!(vi, θ)
DynPPL.setval!(vi, (;θ...))
DynPPL.link!(vi, DynPPL.SampleFromPrior())
ComponentVector(vi)
end

function inv_transform_θ(prob::TuringMuseProblem, θ)
vi = deepcopy(prob.vi_θ)
DynPPL.setval!(vi, θ)
DynPPL.setval!(vi, (;θ...))
for k in keys(θ)
DynPPL.settrans!!(vi, true, _VarName(k))
end
Expand Down Expand Up @@ -211,16 +213,9 @@ end



# helped to extract parameters from a sampled model. feels like there
# should be a less hacky way to do this...
function _namedtuple(vi::DynPPL.VarInfo)
map(DynPPL.TypedVarInfo(vi).metadata) do m
if m.vns[1] isa DynPPL.VarName{<:Any,Setfield.IdentityLens} && length(m.vals)==1
m.vals[1]
else
m.vals
end
end
# `values_as` seems to return Real arays so narrow eltype
map(x -> identity.(x), DynPPL.values_as(vi, NamedTuple))
end

ComponentVector(vi::DynPPL.VarInfo) = ComponentVector(_namedtuple(vi))
Expand Down

0 comments on commit 5443ae3

Please sign in to comment.