Skip to content

Commit d6b3525

Browse files
committed
fix Turing threads and test with threads
1 parent beda0dc commit d6b3525

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

.github/workflows/tests_and_docs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
julia-version: ['1.7', '1.8', '1.9']
16+
threads: ['1', '2']
1617
fail-fast: false
1718
steps:
1819
- uses: actions/checkout@v2
@@ -31,5 +32,6 @@ jobs:
3132
env:
3233
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # if authenticating with GitHub Actions token
3334
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # if authenticating with SSH deploy key
34-
BUILD_DOCS: ${{ matrix.julia-version == '1.7' }} # only build/deploy docs from one version
35+
BUILD_DOCS: ${{ matrix.julia-version == '1.7' && matrix.threads == '1'}} # only build/deploy docs from one version
36+
JULIA_NUM_THREADS: ${{ matrix.threads }}
3537
timeout-minutes: 30

src/turing.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,13 @@ standardizeθ(prob::TuringMuseProblem, θ::Number) =
187187

188188
function logLike(prob::TuringMuseProblem, x, z, θ, θ_space)
189189
trans = is_transformed(θ_space) ? prob.trans_z′_θ′ : prob.trans_z′_θ
190-
vi = DynPPL.SimpleVarInfo((;x..., z..., θ...), 0, trans)
190+
vi = DynPPL.SimpleVarInfo((;x..., z..., θ...), trans)
191191
DynPPL.logjoint(prob.model, vi)
192192
end
193193

194194
function logPriorθ(prob::TuringMuseProblem, θ, θ_space)
195195
trans = is_transformed(θ_space) ? prob.trans_z′_θ′ : prob.trans_z′_θ
196-
vi = DynPPL.SimpleVarInfo((;θ...), 0, trans)
196+
vi = DynPPL.SimpleVarInfo((;θ...), trans)
197197
DynPPL.logprior(prob.model_for_prior, vi)
198198
end
199199

@@ -210,13 +210,23 @@ end
210210

211211
function sample_x_z(prob::TuringMuseProblem, rng::AbstractRNG, θ)
212212
model = DynPPL.condition(prob.model, θ)
213-
vi = DynPPL.SimpleVarInfo((;), 0, prob.trans_z′_θ)
213+
vi = DynPPL.SimpleVarInfo((;θ...), prob.trans_z′_θ)
214214
vars = DynPPL.values_as(last(DynPPL.evaluate!!(model, rng, vi)), NamedTuple)
215215
(x = ComponentVector(select(vars, prob.observed_vars)), z = ComponentVector(select(vars, prob.latent_vars)))
216216
end
217217

218218

219219

220+
# benevolent type-piracy:
221+
function DynPPL.SimpleVarInfo(nt::NamedTuple, trans::DynPPL.AbstractTransformation)
222+
if isempty(nt)
223+
T = DynPPL.SIMPLEVARINFO_DEFAULT_ELTYPE
224+
else
225+
T = DynPPL.float_type_with_fallback(DynPPL.infer_nested_eltype(typeof(nt)))
226+
end
227+
DynPPL.SimpleVarInfo(nt, zero(T), trans)
228+
end
229+
220230
function _namedtuple(vi::DynPPL.VarInfo)
221231
# `values_as` seems to return Real arays so narrow eltype
222232
map(x -> identity.(x), DynPPL.values_as(vi, NamedTuple))

0 commit comments

Comments
 (0)