Skip to content

Commit

Permalink
Fix other issues, but still lifetimes :(
Browse files Browse the repository at this point in the history
  • Loading branch information
mattxwang committed Jul 21, 2023
1 parent 5ac5764 commit d2db106
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 50 deletions.
59 changes: 29 additions & 30 deletions src/bayesian_network.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::collections::{HashMap, HashSet};
use std::iter::FromIterator;
use std::{
collections::{HashMap, HashSet},
iter::FromIterator,
rc::Rc,
};

use rsdd::builder::bdd::{BddBuilder, RobddBuilder};
use rsdd::builder::cache::all_app::AllTable;
Expand All @@ -9,7 +12,6 @@ use rsdd::repr::cnf::Cnf;
use rsdd::repr::ddnnf::DDNNFPtr;
use rsdd::repr::var_label::{Literal, VarLabel};

use assert_approx_eq::assert_approx_eq;
use rsdd::repr::wmc::WmcParams;
use rsdd::util::semirings::{RealSemiring, Semiring};

Expand Down Expand Up @@ -81,21 +83,21 @@ impl Iterator for AssignmentIter {
fn next(&mut self) -> Option<Self::Item> {
if self.cur.is_none() {
self.cur = Some(self.vars.iter().map(|_| 0).collect());
return self.cur.clone();
self.cur.clone()
} else {
// attempt to do a binary increment of the current state
let cur_shape: Vec<usize> = self.vars.iter().map(|x| self.shape[*x]).collect();
for (idx, assgn) in self.cur.as_mut().unwrap().iter_mut().enumerate() {
if *assgn + 1 < cur_shape[idx] {
*assgn = *assgn + 1;
*assgn += 1;
return self.cur.clone();
} else {
// add and carry
*assgn = 0;
}
}
// we failed to add without carrying
return None;
None
}
}
}
Expand Down Expand Up @@ -132,22 +134,22 @@ impl BayesianNetwork {
for cpt in cpts.iter() {
add(&mut sorted, &mut added, &cpts, cpt);
}
return BayesianNetwork {
BayesianNetwork {
shape,
cpts: sorted,
};
}
}

fn cpts_topological(&self) -> std::slice::Iter<'_, CPT> {
fn cpts_topological(&self) -> std::slice::Iter<'a, CPT> {
return self.cpts.iter();
}

fn num_vars(&self) -> usize {
return self.shape.len();
self.shape.len()
}

fn get_shape(&self) -> &Vec<usize> {
return &self.shape;
&self.shape
}
}

Expand Down Expand Up @@ -186,7 +188,7 @@ pub struct CompiledBayesianNetwork<'a> {
}

impl<'a> CompiledBayesianNetwork<'a> {
pub fn new(bn: &BayesianNetwork, _mode: CompileMode) -> CompiledBayesianNetwork<'a> {
pub fn new(bn: &BayesianNetwork, mode: CompileMode) -> CompiledBayesianNetwork {
// the key is (var, state)
let mut varcount = 0;
// (var, value) -> indicator
Expand All @@ -203,15 +205,11 @@ impl<'a> CompiledBayesianNetwork<'a> {
let v = VarLabel::new(varcount);
varcount += 1;
indicators.insert((var_label, cur_value), v);
weight_table.insert(v.clone(), (1.0, 1.0));
weight_table.insert(v, (1.0, 1.0));
vars.push(v);
}
// build exactly-one constraint
clauses.push(
vars.iter()
.map(|var| Literal::new(var.clone(), true))
.collect(),
);
clauses.push(vars.iter().map(|var| Literal::new(*var, true)).collect());
for i in 0..vars.len() {
for j in i..vars.len() {
if i == j {
Expand All @@ -234,11 +232,11 @@ impl<'a> CompiledBayesianNetwork<'a> {
varcount += 1;
params.insert(assignment.clone(), v);
match prob {
&Probability::Concrete(p) => {
weight_table.insert(v.clone(), (1.0, p));
Probability::Concrete(p) => {
weight_table.insert(v, (1.0, *p));
}
&Probability::Symbol(vlbl) => {
symbolic_map.insert(vlbl, v);
Probability::Symbol(vlbl) => {
symbolic_map.insert(*vlbl, v);
}
};
}
Expand Down Expand Up @@ -272,7 +270,7 @@ impl<'a> CompiledBayesianNetwork<'a> {
let cnf = Cnf::new(clauses);
let builder =
RobddBuilder::<'a, AllTable<BddPtr<'a>>>::new_default_order(varcount as usize);
let bdd: BddPtr<'a> = builder.compile_cnf(&cnf);
let bdd = builder.compile_cnf(&cnf);

CompiledBayesianNetwork {
builder,
Expand All @@ -281,12 +279,12 @@ impl<'a> CompiledBayesianNetwork<'a> {
symbolic_map,
shape: bn.shape.clone(),
weights: weight_table,
mode: CompileMode::BottomUpChaviraDarwicheBDD,
mode,
}
}

pub fn get_shape(&'a self) -> &Vec<usize> {
return &self.shape;
&self.shape
}

/// Computes the joint marginal probability of the subset of variables `vars`
Expand Down Expand Up @@ -326,7 +324,7 @@ impl<'a> CompiledBayesianNetwork<'a> {
let wmc = cur_bdd.wmc(self.builder.get_order(), &wmc_param).0;
r.insert(assgn, wmc);
}
return r;
r
}
}

Expand All @@ -344,7 +342,7 @@ fn test_marginal_0() {
)];
let bn = BayesianNetwork::new(shape, cpts);
let compiled = CompiledBayesianNetwork::new(&bn, CompileMode::BottomUpChaviraDarwicheBDD);
let r = compiled.joint_marginal(&SymbolTable::empty(), &vec![0]);
let r = compiled.joint_marginal(&SymbolTable::empty(), &[0]);
assert_eq!(r[&vec![1]], 0.9);
assert_eq!(r[&vec![0]], 0.1);
}
Expand Down Expand Up @@ -375,7 +373,7 @@ fn test_marginal_1() {
];
let bn = BayesianNetwork::new(shape, cpts);
let compiled = CompiledBayesianNetwork::new(&bn, CompileMode::BottomUpChaviraDarwicheBDD);
let r = compiled.joint_marginal(&SymbolTable::empty(), &vec![0, 1]);
let r = compiled.joint_marginal(&SymbolTable::empty(), &[0, 1]);

assert_eq!(r[&vec![0, 0]], 0.3 * 0.1);
assert_eq!(r[&vec![0, 1]], 0.7 * 0.1);
Expand Down Expand Up @@ -411,7 +409,7 @@ fn test_marginal_2() {
];
let bn = BayesianNetwork::new(shape, cpts);
let compiled = CompiledBayesianNetwork::new(&bn, CompileMode::BottomUpChaviraDarwicheBDD);
let r = compiled.joint_marginal(&SymbolTable::empty(), &vec![0, 1]);
let r = compiled.joint_marginal(&SymbolTable::empty(), &[0, 1]);

assert_eq!(r[&vec![0, 2]], 0.1 * 0.5);
assert_eq!(r[&vec![1, 2]], 0.9 * 0.2);
Expand All @@ -420,6 +418,7 @@ fn test_marginal_2() {

#[test]
fn test_marginal_3() {
use assert_approx_eq::assert_approx_eq;
// BN : (a) -> (b) <- (c)
let shape = vec![2, 2, 2];
let cpts = vec![
Expand Down Expand Up @@ -454,7 +453,7 @@ fn test_marginal_3() {
];
let bn = BayesianNetwork::new(shape, cpts);
let compiled = CompiledBayesianNetwork::new(&bn, CompileMode::BottomUpChaviraDarwicheBDD);
let r = compiled.joint_marginal(&SymbolTable::empty(), &vec![0, 1, 2]);
let r = compiled.joint_marginal(&SymbolTable::empty(), &[0, 1, 2]);

assert_approx_eq!(r[&vec![0, 0, 0]], 0.1 * 0.2 * 0.1, 1e-3);
assert_approx_eq!(r[&vec![1, 0, 1]], 0.9 * 0.8 * 0.15, 1e-3);
Expand Down
44 changes: 24 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ use bayesian_network::*;

#[no_mangle]
pub extern "C" fn make_concrete_prob(prob: f64) -> *mut Probability {
return Box::into_raw(Box::new(Probability::Concrete(prob)));
Box::into_raw(Box::new(Probability::Concrete(prob)))
}

/// Make a symbolic probability
#[no_mangle]
pub extern "C" fn make_symbolic_prob(label: usize) -> *mut Probability {
return Box::into_raw(Box::new(Probability::Symbol(label)));
Box::into_raw(Box::new(Probability::Symbol(label)))
}

/// Create a CPT
Expand All @@ -25,8 +25,9 @@ pub extern "C" fn make_symbolic_prob(label: usize) -> *mut Probability {
/// `var`
/// `probabilities`: list of probabilities corresponding to each element in
/// `assignments`
#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn make_cpt(
pub unsafe extern "C" fn make_cpt(
var: usize,
num_parents: usize,
parents: *const usize,
Expand All @@ -52,11 +53,12 @@ pub extern "C" fn make_cpt(
let prob_unwrap: Vec<Probability> = probabilities.iter().map(|x| (**x).clone()).collect();
HashMap::from_iter(assign_vec.into_iter().zip(prob_unwrap))
};
return Box::into_raw(Box::new(CPT::new(var, parent_vec, prob)));
Box::into_raw(Box::new(CPT::new(var, parent_vec, prob)))
}

#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn make_bayesian_network(
pub unsafe extern "C" fn make_bayesian_network(
num_vars: usize,
shape: *const usize,
num_cpts: usize,
Expand All @@ -67,48 +69,50 @@ pub extern "C" fn make_bayesian_network(
let cpt_ptrvec: &[*const CPT] = std::slice::from_raw_parts(cpts, num_cpts);
let cpt: Vec<CPT> = cpt_ptrvec.iter().map(|x| (**x).clone()).collect();
let bn = BayesianNetwork::new(shape, cpt);
return Box::into_raw(Box::new(bn));
Box::into_raw(Box::new(bn))
}
}

/// Construct a new symbol label that mapes the variable indexed by labels[i] to
/// weight probs[i]
#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn make_symbol_table(
pub unsafe extern "C" fn make_symbol_table(
num_symbols: usize,
labels: *const usize,
probs: *const f64,
) -> *mut SymbolTable {
unsafe {
let labels = std::slice::from_raw_parts(labels, num_symbols);
let probs = std::slice::from_raw_parts(probs, num_symbols);
let map: HashMap<&usize, &f64> = HashMap::from_iter(labels.into_iter().zip(probs));
let map: HashMap<&usize, &f64> = HashMap::from_iter(labels.iter().zip(probs));
let map_into = map.iter().map(|(&key, &value)| (*key, *value)).collect();
return Box::into_raw(Box::new(SymbolTable::new(map_into)));
Box::into_raw(Box::new(SymbolTable::new(map_into)))
}
}

#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn make_empty_symbol_table() -> *mut SymbolTable {
return Box::into_raw(Box::new(SymbolTable::empty()));
pub unsafe extern "C" fn make_empty_symbol_table() -> *mut SymbolTable {
Box::into_raw(Box::new(SymbolTable::empty()))
}

#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn compile_bayesian_network(
pub unsafe extern "C" fn compile_bayesian_network(
bn: *const BayesianNetwork,
) -> *mut CompiledBayesianNetwork<'static> {
unsafe {
return Box::into_raw(Box::new(CompiledBayesianNetwork::new(
bn.as_ref().unwrap(),
CompileMode::BottomUpChaviraDarwicheBDD,
)));
}
Box::into_raw(Box::new(CompiledBayesianNetwork::new(
bn.as_ref().unwrap(),
CompileMode::BottomUpChaviraDarwicheBDD,
)))
}

/// Compute the joint marginal for the subset of variables given in `vars`
/// Gives ownership of the returned array to the caller
#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub extern "C" fn joint_marginal(
pub unsafe extern "C" fn joint_marginal(
bn: *mut CompiledBayesianNetwork,
st: *const SymbolTable,
num_vars: usize,
Expand All @@ -124,6 +128,6 @@ pub extern "C" fn joint_marginal(
}
let r = v.as_ptr();
std::mem::forget(v);
return r;
r
}
}

0 comments on commit d2db106

Please sign in to comment.