feat: Add low rank modified mass matrix adaptation
aseyboldt committed Jul 5, 2024
1 parent c2e6ea6 commit 9a477fa
Showing 6 changed files with 340 additions and 124 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ notebooks/*.hpp
 perf.data*
 wheels
 .vscode/
+*~
72 changes: 35 additions & 37 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -22,7 +22,7 @@ name = "_lib"
 crate-type = ["cdylib"]
 
 [dependencies]
-nuts-rs = "0.11.0"
+nuts-rs = "0.12.0"
 numpy = "0.21.0"
 ndarray = "0.15.6"
 rand = "0.8.5"
@@ -33,7 +33,7 @@ rayon = "1.9.0"
 arrow = { version = "52.0.0", default-features = false, features = ["ffi"] }
 anyhow = "1.0.72"
 itertools = "0.13.0"
-bridgestan = "2.4.1"
+bridgestan = "2.5.0"
 rand_distr = "0.4.3"
 smallvec = "1.11.0"
 upon = { version = "0.8.1", default-features = false, features = [] }
12 changes: 11 additions & 1 deletion python/nutpie/sample.py
@@ -460,6 +460,7 @@ def sample(
     seed: Optional[int],
     save_warmup: bool,
     progress_bar: bool,
+    low_rank_modified_mass_matrix: bool = False,
     init_mean: Optional[np.ndarray],
     return_raw_trace: bool,
     blocking: Literal[True],
@@ -478,6 +479,7 @@ def sample(
     seed: Optional[int],
     save_warmup: bool,
     progress_bar: bool,
+    low_rank_modified_mass_matrix: bool = False,
     init_mean: Optional[np.ndarray],
     return_raw_trace: bool,
     blocking: Literal[False],
@@ -495,6 +497,7 @@ def sample(
     seed: Optional[int] = None,
     save_warmup: bool = True,
     progress_bar: bool = True,
+    low_rank_modified_mass_matrix: bool = False,
     init_mean: Optional[np.ndarray] = None,
     return_raw_trace: bool = False,
     blocking: bool = True,
@@ -569,6 +572,9 @@ def sample(
         for the progress bar (eg CSS).
     progress_rate: int, default=500
        Rate in ms at which the progress should be updated.
+    low_rank_modified_mass_matrix: bool, default=False
+        Allow adaptation to some posterior correlations using
+        a low-rank updated mass matrix.
     **kwargs
         Pass additional arguments to nutpie._lib.PySamplerArgs
@@ -577,7 +583,11 @@ def sample(
     trace : arviz.InferenceData
         An ArviZ ``InferenceData`` object that contains the samples.
     """
-    settings = _lib.PyDiagGradNutsSettings(seed)
+
+    if low_rank_modified_mass_matrix:
+        settings = _lib.PyNutsSettings.LowRank(seed)
+    else:
+        settings = _lib.PyNutsSettings.Diag(seed)
     settings.num_tune = tune
     settings.num_draws = draws
     settings.num_chains = chains
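For context, here is a minimal sketch of how the new option would be used from Python once this change is installed. The toy PyMC model is illustrative only; nutpie.compile_pymc_model is nutpie's existing compilation helper, and only the low_rank_modified_mass_matrix keyword comes from this commit.

import numpy as np
import pymc as pm
import nutpie

# Toy model with strongly correlated parameters, where a purely diagonal
# mass matrix adapts poorly (illustrative, not part of this commit).
with pm.Model() as model:
    cov = np.array([[1.0, 0.9], [0.9, 1.0]])
    pm.MvNormal("x", mu=np.zeros(2), cov=cov)

compiled = nutpie.compile_pymc_model(model)

# The new keyword switches the sampler settings from PyNutsSettings.Diag
# to PyNutsSettings.LowRank, as in the branch added in this diff.
trace = nutpie.sample(compiled, low_rank_modified_mass_matrix=True)

With the default low_rank_modified_mass_matrix=False, the sampler keeps the previous diagonal mass matrix behaviour.
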
12 changes: 6 additions & 6 deletions src/pyfunc.rs
@@ -41,7 +41,7 @@ impl PyVariable
                 let field = Arc::new(Field::new("item", DataType::Boolean, false));
                 DataType::FixedSizeList(field, tensor_type.size() as i32)
             }
-            ExpandDtype::ArrayFloat64 { tensor_type } => {
+            ExpandDtype::ArrayFloat64 { tensor_type: _ } => {
                 let field = Arc::new(Field::new("item", DataType::Float64, true));
                 DataType::List(field)
             }
@@ -303,11 +303,11 @@ impl ExpandDtype {
     #[getter]
     fn shape(&self) -> Option<Vec<usize>> {
         match self {
-            Self::BooleanArray {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
-            Self::ArrayFloat64 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
-            Self::ArrayFloat32 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
-            Self::ArrayInt64 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
-            _ => { None },
+            Self::BooleanArray { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
+            Self::ArrayFloat64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
+            Self::ArrayFloat32 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
+            Self::ArrayInt64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
+            _ => None,
         }
     }
 }
(Diff for the sixth changed file was not rendered.)
