Fix nuts v4 #545
base: main
Conversation
reference_x = reference["x"]

# Compare samples
assert np.allclose(samples["s"].samples, reference_s, rtol=1e-3)
Note: on my machine, I did not need to set rtol at all; the default value worked. I needed to set it to 1e-3 to pass GitHub Actions.
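A minimal illustration (not from the PR itself) of how loosening rtol changes the outcome of np.allclose, which is why a platform-dependent tolerance can be needed:

```python
import numpy as np

# Values that agree to roughly 3 significant digits, as MCMC samples
# produced on different platforms might.
a = np.array([1.0, 2.0])
b = np.array([1.0005, 2.001])

print(np.allclose(a, b))             # default rtol=1e-5 -> False
print(np.allclose(a, b, rtol=1e-3))  # looser tolerance  -> True
```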
Only fair. NUTS has been difficult to test.
Hi @amal-ghamdi. Really nice! I only had a comment related to the deepcopy. Perhaps we can find a more efficient method.
self._initialize_samplers()

# Run over pre-sample methods for samplers that have it
# TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
# This is not ideal and should be fixed in the future
for sampler in self.samplers.values():
    self._pre_warmup_and_pre_sample_sampler(sampler)
🎉
cuqi/experimental/mcmc/_hmc.py
@@ -212,9 +216,9 @@ def step(self):
        self._num_tree_node = 0

        # copy current point, logd, and grad in local variables
-       point_k = self.current_point.copy() # initial position (parameters)
+       point_k = deepcopy(self.current_point) # initial position (parameters)
Are you sure we need this deep copy? One could try a more efficient operation like the "slicing trick" self.current_point[:]. We could wrap this in a method for NUTS, like "self._copy_array()" or something.
It would also be good to have a bit of explanation as to why it is needed.
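One caveat with the suggested slicing trick, sketched below (a minimal illustration, not code from the PR): for a NumPy array, `x[:]` returns a view that shares memory with the original, so mutating the original also changes the "copy". An actual copy needs np.copy (which also accepts scalars):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])

view = x[:]        # for NumPy arrays, slicing returns a view, not a copy
copy = np.copy(x)  # an independent copy of the data

x[0] = 99.0
print(view[0])  # 99.0 -- the view reflects the change
print(copy[0])  # 1.0  -- the copy does not
```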
Thank you @nabriis, very good point. I think the deep copy added a bit of overhead, though maybe not significant, and also not significant relative to the computation time. For example, the Gibbs example used to run in 19.6 s, but after the modification with deep copy it runs in 20.7 s.
To address this, I reverted to using numpy.copy, which I read is more efficient than deep copy. I handled the cases where the value could be a scalar or an array with an if/else.
The reason I used deepcopy in the first place is that some values, like logd, are sometimes an array and sometimes a scalar.
closes #524
closes #537
Verification Code 1:
Code verifies this change still gives same results as old NUTS (also this is verified by regression tests)
Results main branch:
Results this branch:
Verification Code 2 (added as a test in the PR, for smaller sample size):
Results main branch:
Results this branch: