Skip to content

Commit

Permalink
Merge branch 'turing_0_28'
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Aug 27, 2023
2 parents 4e66c41 + 0764734 commit cb6def9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ function DynPPL.istrans(vi::DynPPL.SimpleVarInfo{NT,T,<:PartialTransformation},
vn in vi.transformation.transformed_vns
end

function DynPPL.maybe_invlink_before_eval!!(vi::DynPPL.SimpleVarInfo{NT,T,<:PartialTransformation}, context::DynPPL.AbstractContext, model::DynPPL.Model) where {NT,T}
vi
end


struct TuringMuseProblem{A<:AD.AbstractBackend, M<:Turing.Model} <: AbstractMuseProblem

Expand Down Expand Up @@ -163,7 +167,7 @@ end
function transform_θ(prob::TuringMuseProblem, θ)
vi = deepcopy(prob.vi_θ)
DynPPL.setval!(vi, (;θ...))
DynPPL.link!(vi, DynPPL.SampleFromPrior())
DynPPL.link!!(vi, DynPPL.SampleFromPrior(), prob.model)
ComponentVector(vi)
end

Expand All @@ -173,7 +177,7 @@ function inv_transform_θ(prob::TuringMuseProblem, θ)
for k in keys(θ)
DynPPL.settrans!!(vi, true, _VarName(k))
end
DynPPL.invlink!(vi, DynPPL.SampleFromPrior())
DynPPL.invlink!!(vi, DynPPL.SampleFromPrior(), prob.model)
ComponentVector(vi)
end

Expand Down
3 changes: 1 addition & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SampleChainsDynamicHMC = "6d9fd711-e8b2-4778-9c70-c1dfb499d4c4"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -24,4 +23,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Soss = "0.21.2"
Turing = "0.21.10"
Turing = "0.28"

0 comments on commit cb6def9

Please sign in to comment.