From 9e4c98b10790b0cd7da155f0e58b14a3a62b986b Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Wed, 8 Nov 2023 14:41:27 +0100 Subject: [PATCH] update demo35 to test NUTS and ULA with FD approximated gradients --- demos/demo35_FD_gradient.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/demos/demo35_FD_gradient.py b/demos/demo35_FD_gradient.py index f08ccb2a4..fb65f41e0 100644 --- a/demos/demo35_FD_gradient.py +++ b/demos/demo35_FD_gradient.py @@ -11,8 +11,10 @@ #%% Sample from the posterior using Metropolis-Hastings print('Sampling from the posterior using Metropolis-Hastings:') MH_sampler = cuqi.sampler.MH(posterior) -MH_samples = MH_sampler.sample_adapt(1000) +MH_samples = MH_sampler.sample_adapt(1000, 100) +plt.figure() MH_samples.plot_ci(95,exact=TP.exactSolution) +plt.title("MH") #%% Sample from the posterior using MALA print('Sampling from the posterior using MALA:') @@ -22,19 +24,37 @@ MALA_samples = MALA_sampler.sample_adapt(1000) except Exception as e: print(e) +print('Sampling failed because the gradient of the posterior is not available.') -print('Sampling failed because the gradient of the posterior is not available.') +#%% print('Enable finite difference approximation of the gradient ' + - 'for the posterior, and attempt sampling again using MALA:') + 'for the posterior, and attempt sampling again using MALA, ULA and NUTS:') posterior.enable_FD() MALA_sampler = cuqi.sampler.MALA(posterior, 0.0001) -MALA_samples = MALA_sampler.sample_adapt(1000) +MALA_samples = MALA_sampler.sample_adapt(1000, 10) plt.figure() MALA_samples.plot_ci(95,exact=TP.exactSolution) +plt.title("MALA") + +#%% Sample from the posterior using ULA +ULA_sampler = cuqi.sampler.ULA(posterior, 0.0001) +ULA_samples = ULA_sampler.sample_adapt(1000, 10) +plt.figure() +ULA_samples.plot_ci(95,exact=TP.exactSolution) +plt.title("ULA") + +#%% Sample from the posterior using NUTS +NUTS_sampler = cuqi.sampler.NUTS(posterior) +NUTS_samples = NUTS_sampler.sample_adapt(1000, 10) +plt.figure() +NUTS_samples.plot_ci(95, exact=TP.exactSolution) +plt.title("NUTS") -#%% Plot the ESS of the two chains +#%% Plot the ESS of all the chains plt.figure() plt.plot(MH_samples.compute_ess(), label='MH ESS', marker='o') plt.plot(MALA_samples.compute_ess(), label='MALA ESS', marker='o') -plt.legend() +plt.plot(ULA_samples.compute_ess(), label='ULA ESS', marker='o') +plt.plot(NUTS_samples.compute_ess(), label='NUTS ESS', marker='o') +plt.legend() \ No newline at end of file