Skip to content

Commit 3f60545

Browse files
committed
warn and dont test with threads/Zygote
1 parent d6b3525 commit 3f60545

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/turing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ function TuringMuseProblem(
125125
error("Unsupposed backend from Turing: $(Turing.ADBACKEND)")
126126
end
127127
end
128+
if (Threads.nthreads() > 1) && hasmethod(AD.ZygoteBackend,Tuple{}) && (autodiff isa typeof(AD.ZygoteBackend()))
129+
error("Turing doesn't support using the Zygote backend when Threads.nthreads()>1. Use a different backend or a single-thread.")
130+
end
131+
128132
# ensure tuple
129133
params = (params...,)
130134
# prevent this constructor from advancing the default RNG for more clear reproducibility

test/runtests.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ rng = StableRNG(0)
2222
("Zygote", AD.ZygoteBackend())
2323
]
2424

25-
(;x) = rand(copy(rng), turing_funnel() |=0,))
26-
prob = TuringMuseProblem(turing_funnel() | (;x); autodiff)
27-
MuseInference.check_self_consistency(prob, (θ=1,), has_volume_factor=true, rng=copy(rng))
28-
result = muse(prob, (θ=1,); rng=copy(rng), get_covariance=true)
29-
@test result.dist.μ / result.dist.σ < 2
25+
if !(name=="Zygote" && Threads.nthreads()>1)
26+
27+
(;x) = rand(copy(rng), turing_funnel() |=0,))
28+
prob = TuringMuseProblem(turing_funnel() | (;x); autodiff)
29+
MuseInference.check_self_consistency(prob, (θ=1,), has_volume_factor=true, rng=copy(rng))
30+
result = muse(prob, (θ=1,); rng=copy(rng), get_covariance=true)
31+
@test result.dist.μ / result.dist.σ < 2
32+
33+
end
3034

3135
end
3236

0 commit comments

Comments
 (0)