Skip to content

Commit

Permalink
Merge pull request #307 from CUQI-DTU/sprint22_check_NUTS_with_FD_gra…
Browse files Browse the repository at this point in the history
…dient

update demos/demo35_FD_gradient.py to test NUTS and ULA with FD approximated gradients
  • Loading branch information
chaozg authored Nov 9, 2023
2 parents 9ec01c9 + 9e4c98b commit dab78f8
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions demos/demo35_FD_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#%% Sample from the posterior using Metropolis-Hastings
print('Sampling from the posterior using Metropolis-Hastings:')
MH_sampler = cuqi.sampler.MH(posterior)
MH_samples = MH_sampler.sample_adapt(1000)
MH_samples = MH_sampler.sample_adapt(1000, 100)
plt.figure()
MH_samples.plot_ci(95,exact=TP.exactSolution)
plt.title("MH")

#%% Sample from the posterior using MALA
print('Sampling from the posterior using MALA:')
Expand All @@ -22,19 +24,37 @@
MALA_samples = MALA_sampler.sample_adapt(1000)
except Exception as e:
print(e)
print('Sampling failed because the gradient of the posterior is not available.')

print('Sampling failed because the gradient of the posterior is not available.')
#%%
print('Enable finite difference approximation of the gradient ' +
'for the posterior, and attempt sampling again using MALA:')
'for the posterior, and attempt sampling again using MALA, ULA and NUTS:')

posterior.enable_FD()
MALA_sampler = cuqi.sampler.MALA(posterior, 0.0001)
MALA_samples = MALA_sampler.sample_adapt(1000)
MALA_samples = MALA_sampler.sample_adapt(1000, 10)
plt.figure()
MALA_samples.plot_ci(95,exact=TP.exactSolution)
plt.title("MALA")

#%% Sample from the posterior using ULA
ULA_sampler = cuqi.sampler.ULA(posterior, 0.0001)
ULA_samples = ULA_sampler.sample_adapt(1000, 10)
plt.figure()
ULA_samples.plot_ci(95,exact=TP.exactSolution)
plt.title("ULA")

#%% Sample from the posterior using NUTS
NUTS_sampler = cuqi.sampler.NUTS(posterior)
NUTS_samples = NUTS_sampler.sample_adapt(1000, 10)
plt.figure()
NUTS_samples.plot_ci(95, exact=TP.exactSolution)
plt.title("NUTS")

#%% Plot the ESS of the two chains
#%% Plot the ESS of all the chains
plt.figure()
plt.plot(MH_samples.compute_ess(), label='MH ESS', marker='o')
plt.plot(MALA_samples.compute_ess(), label='MALA ESS', marker='o')
plt.legend()
plt.plot(ULA_samples.compute_ess(), label='ULA ESS', marker='o')
plt.plot(NUTS_samples.compute_ess(), label='NUTS ESS', marker='o')
plt.legend()

0 comments on commit dab78f8

Please sign in to comment.