
Fix nuts v4 #545

Open · wants to merge 11 commits into main
Conversation

amal-ghamdi (Contributor) commented Oct 6, 2024:

closes #524
closes #537

Verification Code 1:

This code verifies that the change still gives the same results as the old NUTS (this is also verified by regression tests):

from git import Repo
repo = Repo('./CUQIpy')
branch = repo.active_branch
# Sample with NUTS + time it
import time
import numpy as np
from cuqi.testproblem import Deconvolution1D
from cuqi.experimental.mcmc import NUTS
import matplotlib.pyplot as plt

# Inverse problem
np.random.seed(0) 
testproblem_obj = Deconvolution1D(dim=128, phantom='sinc', noise_std=0.001)
posterior = testproblem_obj.posterior

# NUTS
Ns=15
Nb=15
print(branch.name)
print('NUTS')
start = time.time()
np.random.seed(0)
NUTS_sampler = NUTS(posterior, step_size=None)
NUTS_sampler.warmup(Nb, tune_freq=1/Nb).sample(Ns)

end = time.time()
print('Time:', end - start)
print('samples norm:', np.linalg.norm(NUTS_sampler.get_samples().samples))
plt.plot([np.linalg.norm(NUTS_sampler.get_samples().samples[:,i]) for i in range(Ns+Nb)])
plt.ylim([0, 14])
# add branch name in title
plt.title(branch.name)
plt.figure()
plt.plot(NUTS_sampler.num_tree_node_list)
plt.title(branch.name)

Results, main branch (images): nuts_main_chain, nuts_main_tree

Results, this branch (images): nuts_branch_chain, nuts_branch_tree

Verification Code 2 (added as a test in the PR, with a smaller sample size):

import numpy as np
from cuqi.distribution import Gamma, Gaussian, GMRF, JointDistribution
from cuqi.experimental.mcmc import NUTS, HybridGibbs, Conjugate
from cuqi.testproblem import Deconvolution1D
from git import Repo
import matplotlib.pyplot as plt
import time
repo = Repo('../')
branch = repo.active_branch

# Forward problem
np.random.seed(0)
A, y_data, info = Deconvolution1D(dim=128, phantom='sinc', noise_std=0.001).get_components()

# Bayesian Inverse Problem
s = Gamma(1, 1e-4)
x = GMRF(np.zeros(A.domain_dim), 50)
y = Gaussian(A@x, lambda s: 1/s)

# Posterior
target = JointDistribution(y, x, s)(y=y_data)

Nb=40
sampling_strategy = {
    "x" : NUTS(max_depth=7),
    "s" : Conjugate()
}

# Here we do 1 internal step with NUTS for each Gibbs step
num_sampling_steps = {
    "x" : 1,
    "s" : 1
}

sampler = HybridGibbs(target, sampling_strategy, num_sampling_steps)
# start time
start_time = time.time()
sampler.warmup(Nb)
sampler.sample(40)
samples = sampler.get_samples()
# end time
end_time = time.time()
print(end_time - start_time)


samples["x"].plot_ci(exact=info.exactSolution)
plt.title(branch.name +" time: "+str(end_time - start_time))
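For readers unfamiliar with the num_sampling_steps idea, here is a schematic of a hybrid Gibbs loop with per-variable internal steps. This is a pure-Python toy with made-up conditional samplers, not the CUQIpy implementation; the two sampler functions are stand-ins for the NUTS and Conjugate steps above:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conditional samplers for two variables; stand-ins for NUTS/Conjugate.
def sample_x_given_s(x, s):
    # One MCMC step on x given the current s
    return x + rng.normal(scale=1.0 / np.sqrt(s))

def sample_s_given_x(x, s):
    # Direct conjugate-style draw for s given the current x
    return rng.gamma(shape=2.0, scale=1.0 / (1.0 + x**2))

num_sampling_steps = {"x": 1, "s": 1}  # internal steps per Gibbs sweep

x, s = 0.0, 1.0
chain = []
for _ in range(40):  # Gibbs sweeps
    for _ in range(num_sampling_steps["x"]):
        x = sample_x_given_s(x, s)
    for _ in range(num_sampling_steps["s"]):
        s = sample_s_given_x(x, s)
    chain.append((x, s))

print(len(chain))  # 40
```

Raising an entry of num_sampling_steps runs that variable's sampler several times per sweep while the other variables stay fixed.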

Results, main branch (image): gibbs_main

Results, this branch (image): gibbs_branch

@amal-ghamdi amal-ghamdi mentioned this pull request Oct 6, 2024
reference_x = reference["x"]

# Compare samples
assert np.allclose(samples["s"].samples, reference_s, rtol=1e-3)
amal-ghamdi (Contributor, Author) replied:
note: on my machine, I did not need to set rtol at all; the default value worked. I needed to set it to 1e-3 to pass GitHub Actions.
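As background on the tolerance choice, np.allclose compares arrays elementwise against a relative tolerance rtol (default 1e-5) plus an absolute tolerance atol. A minimal standalone illustration (the numbers here are made up, not from the PR):

```python
import numpy as np

a = np.array([1.0000, 2.0000])
b = np.array([1.0004, 2.0008])  # ~4e-4 relative difference

# With the default rtol of 1e-5, these arrays are NOT close
print(np.allclose(a, b))             # False

# Loosening rtol to 1e-3 tolerates small cross-platform numeric drift
print(np.allclose(a, b, rtol=1e-3))  # True
```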

Collaborator replied:
Only fair. NUTS has been difficult to test.

nabriis (Collaborator) left a comment:
Hi @amal-ghamdi. Really nice! I only had one comment, related to the deepcopy. Perhaps we can find a more efficient method.

Comment on lines -139 to -145
self._initialize_samplers()

# Run over pre-sample methods for samplers that have it
# TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
# This is not ideal and should be fixed in the future
for sampler in self.samplers.values():
    self._pre_warmup_and_pre_sample_sampler(sampler)
Collaborator replied:
🎉

@@ -212,9 +216,9 @@ def step(self):
        self._num_tree_node = 0

        # copy current point, logd, and grad in local variables
-       point_k = self.current_point.copy() # initial position (parameters)
+       point_k = deepcopy(self.current_point) # initial position (parameters)
Collaborator commented:
Are you sure we need this deep copy? One could try a more efficient operation like the "slicing trick" self.current_point[:]. We could wrap this in a method for NUTS like "self._copy_array()" or something.

It would also be good to have a bit of explanation as to why it is needed.

amal-ghamdi (Contributor, Author) replied Oct 10, 2024:
Thank you @nabriis, very good point. I think deepcopy added a bit of overhead, probably not significant, and also not significant relative to the computation time. For example, the Gibbs example used to run in 19.6 s, but with the deepcopy modification it runs in 20.7 s.

To address this, I reverted to using numpy.copy, which I read is more efficient than deepcopy. I handled the cases where the value could be a scalar or an array with an if/else.

The reason I used deepcopy in the first place is that some values, like logd, are sometimes an array and sometimes a scalar.
