From 1c23df31e9d0f9ec26404b9149a2a4b700444bc0 Mon Sep 17 00:00:00 2001 From: marius Date: Tue, 23 Apr 2024 13:59:57 -0700 Subject: [PATCH] fix docs --- .github/workflows/tests_and_docs.yml | 9 +------- docs/src/index.md | 32 +++++++--------------------- test/Project.toml | 2 +- 3 files changed, 10 insertions(+), 33 deletions(-) diff --git a/.github/workflows/tests_and_docs.yml b/.github/workflows/tests_and_docs.yml index cdf8d41..d8ab61d 100644 --- a/.github/workflows/tests_and_docs.yml +++ b/.github/workflows/tests_and_docs.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - julia-version: ['1.7', '1.8', '1.9', '~1.10.0-0'] + julia-version: ['1.7', '1.8', '1.9', '1.10'] threads: ['1', '2'] fail-fast: false steps: @@ -20,13 +20,6 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} - - uses: actions/setup-python@v2 - with: - python-version: '3.8' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install matplotlib - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: diff --git a/docs/src/index.md b/docs/src/index.md index 256b9c3..40d8d82 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -33,9 +33,8 @@ First, load up the packages we'll need: ```@example 1 using MuseInference, Turing -using AbstractDifferentiation, Dates, LinearAlgebra, Printf, PyPlot, Random, Zygote +using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote Turing.setadbackend(:zygote) -PyPlot.ioff() # hide using Logging # hide Logging.disable_logging(Logging.Info) # hide Turing.AdvancedVI.PROGRESS[] = false # hide @@ -85,7 +84,7 @@ nothing # hide We next compute the MUSE estimate for the same problem. To reach the same Monte Carlo error as HMC, the number of MUSE simulations should be the same as the effective sample size of the chain we just ran. This is: ```@example 1 -nsims = round(Int, ess_rhat(chain)[:θ,:ess]) +nsims = round(Int, ess(chain)[:θ,:ess]) ``` Running the MUSE estimate, @@ -97,29 +96,14 @@ muse_result = @time muse(model, 0; nsims, get_covariance=true) nothing # hide ``` -Lets also try mean-field variational inference (MFVI) to compare to another approximate method. +Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer in a fraction of the time. ```@example 1 -Random.seed!(4) -vi(model, ADVI(10, 10)) # warmup # hide -t_vi = @time @elapsed vi_result = vi(model, ADVI(10, 1000)) -nothing # hide -``` - -Now let's plot the different estimates. In this case, MUSE gives a nearly perfect answer at a fraction of the computational cost. MFVI struggles in both speed and accuracy by comparison. - -```@example 1 -figure(figsize=(6,5)) # hide -axvline(0, c="k", ls="--", alpha=0.5) -hist(collect(chain["θ"][:]), density=true, bins=15, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time)) +histogram(collect(chain["θ"][:]), normalize=:pdf, bins=10, label=@sprintf("HMC (%.1f seconds)", chain.info.stop_time - chain.info.start_time)) θs = range(-0.5,0.5,length=1000) -plot(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000)))) -plot(θs, pdf.(Normal(vi_result.dist.m[1], vi_result.dist.σ[1]), θs), label=@sprintf("MFVI (%.1f seconds)", t_vi)) -legend() -xlabel(L"\theta") -ylabel(L"\mathcal{P}(\theta\,|\,x)") -title("2048-dimensional noisy funnel") -gcf() # hide +plot!(θs, pdf.(muse_result.dist, θs), label=@sprintf("MUSE (%.1f seconds)", (muse_result.time / Millisecond(1000))), lw=2) +vline!([0], c=:black, ls=:dash, alpha=0.5, label=nothing) +plot!(xlabel=L"\theta", ylabel=L"\mathcal{P}(\theta\,|\,x)", title="2048-dimensional noisy funnel") ``` The timing[^1] difference is indicative of the speedups over HMC that are possible. These get even more dramatic as we increase dimensionality, which is why MUSE really excels on high-dimensional problems. @@ -180,7 +164,7 @@ prob = SimpleMuseProblem( function logPrior(θ) -θ^2/(2*3^2) end; - autodiff = AD.ZygoteBackend() + autodiff = AbstractDifferentiation.ZygoteBackend() ) nothing # hide ``` diff --git a/test/Project.toml b/test/Project.toml index 02fb19b..04bafc1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,7 +12,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"