Skip to content

Commit fe8a8c0

Browse files
authored
Merge pull request #16 from marius311/turingthreads
fix Turing threads and test with threads
2 parents beda0dc + 3f60545 commit fe8a8c0

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

.github/workflows/tests_and_docs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
julia-version: ['1.7', '1.8', '1.9']
16+
threads: ['1', '2']
1617
fail-fast: false
1718
steps:
1819
- uses: actions/checkout@v2
@@ -31,5 +32,6 @@ jobs:
3132
env:
3233
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # if authenticating with GitHub Actions token
3334
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # if authenticating with SSH deploy key
34-
BUILD_DOCS: ${{ matrix.julia-version == '1.7' }} # only build/deploy docs from one version
35+
BUILD_DOCS: ${{ matrix.julia-version == '1.7' && matrix.threads == '1'}} # only build/deploy docs from one version
36+
JULIA_NUM_THREADS: ${{ matrix.threads }}
3537
timeout-minutes: 30

src/turing.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ function TuringMuseProblem(
125125
error("Unsupposed backend from Turing: $(Turing.ADBACKEND)")
126126
end
127127
end
128+
if (Threads.nthreads() > 1) && hasmethod(AD.ZygoteBackend,Tuple{}) && (autodiff isa typeof(AD.ZygoteBackend()))
129+
error("Turing doesn't support using the Zygote backend when Threads.nthreads()>1. Use a different backend or a single-thread.")
130+
end
131+
128132
# ensure tuple
129133
params = (params...,)
130134
# prevent this constructor from advancing the default RNG for more clear reproducibility
@@ -187,13 +191,13 @@ standardizeθ(prob::TuringMuseProblem, θ::Number) =
187191

188192
function logLike(prob::TuringMuseProblem, x, z, θ, θ_space)
189193
trans = is_transformed(θ_space) ? prob.trans_z′_θ′ : prob.trans_z′_θ
190-
vi = DynPPL.SimpleVarInfo((;x..., z..., θ...), 0, trans)
194+
vi = DynPPL.SimpleVarInfo((;x..., z..., θ...), trans)
191195
DynPPL.logjoint(prob.model, vi)
192196
end
193197

194198
function logPriorθ(prob::TuringMuseProblem, θ, θ_space)
195199
trans = is_transformed(θ_space) ? prob.trans_z′_θ′ : prob.trans_z′_θ
196-
vi = DynPPL.SimpleVarInfo((;θ...), 0, trans)
200+
vi = DynPPL.SimpleVarInfo((;θ...), trans)
197201
DynPPL.logprior(prob.model_for_prior, vi)
198202
end
199203

@@ -210,13 +214,23 @@ end
210214

211215
function sample_x_z(prob::TuringMuseProblem, rng::AbstractRNG, θ)
212216
model = DynPPL.condition(prob.model, θ)
213-
vi = DynPPL.SimpleVarInfo((;), 0, prob.trans_z′_θ)
217+
vi = DynPPL.SimpleVarInfo((;θ...), prob.trans_z′_θ)
214218
vars = DynPPL.values_as(last(DynPPL.evaluate!!(model, rng, vi)), NamedTuple)
215219
(x = ComponentVector(select(vars, prob.observed_vars)), z = ComponentVector(select(vars, prob.latent_vars)))
216220
end
217221

218222

219223

224+
# benevolent type-piracy:
225+
function DynPPL.SimpleVarInfo(nt::NamedTuple, trans::DynPPL.AbstractTransformation)
226+
if isempty(nt)
227+
T = DynPPL.SIMPLEVARINFO_DEFAULT_ELTYPE
228+
else
229+
T = DynPPL.float_type_with_fallback(DynPPL.infer_nested_eltype(typeof(nt)))
230+
end
231+
DynPPL.SimpleVarInfo(nt, zero(T), trans)
232+
end
233+
220234
function _namedtuple(vi::DynPPL.VarInfo)
221235
# `values_as` seems to return Real arays so narrow eltype
222236
map(x -> identity.(x), DynPPL.values_as(vi, NamedTuple))

test/runtests.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ rng = StableRNG(0)
2222
("Zygote", AD.ZygoteBackend())
2323
]
2424

25-
(;x) = rand(copy(rng), turing_funnel() |=0,))
26-
prob = TuringMuseProblem(turing_funnel() | (;x); autodiff)
27-
MuseInference.check_self_consistency(prob, (θ=1,), has_volume_factor=true, rng=copy(rng))
28-
result = muse(prob, (θ=1,); rng=copy(rng), get_covariance=true)
29-
@test result.dist.μ / result.dist.σ < 2
25+
if !(name=="Zygote" && Threads.nthreads()>1)
26+
27+
(;x) = rand(copy(rng), turing_funnel() |=0,))
28+
prob = TuringMuseProblem(turing_funnel() | (;x); autodiff)
29+
MuseInference.check_self_consistency(prob, (θ=1,), has_volume_factor=true, rng=copy(rng))
30+
result = muse(prob, (θ=1,); rng=copy(rng), get_covariance=true)
31+
@test result.dist.μ / result.dist.σ < 2
32+
33+
end
3034

3135
end
3236

0 commit comments

Comments
 (0)