diff --git a/src/lib.rs b/src/lib.rs index ecfb2cc..46d8972 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ use rsdd::{ builder::{bdd::RobddBuilder, cache::AllIteTable, BottomUpBuilder}, constants::primes, repr::{BddPtr, Cnf, DDNNFPtr, PartialModel, VarLabel, VarOrder, WmcParams}, - util::semirings::{ExpectedUtility, FiniteField, Semiring}, + util::semirings::{ExpectedUtility, FiniteField, RealSemiring, Semiring}, }; #[ocaml::sig] @@ -52,10 +52,13 @@ pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer bool -> rsdd_bdd_ptr")] -pub fn bdd_new_var(builder: &'static RsddBddBuilder, polarity: bool) -> ocaml::Pointer { - let (_, ptr) = builder.0.new_var(polarity); - RsddBddPtr(ptr).into() +#[ocaml::sig("rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr)")] +pub fn bdd_new_var( + builder: &'static RsddBddBuilder, + polarity: bool, +) -> (u64, ocaml::Pointer) { + let (lbl, ptr) = builder.0.new_var(polarity); + (lbl.value(), RsddBddPtr(ptr).into()) } #[ocaml::func] @@ -155,6 +158,36 @@ pub fn bdd_high(bdd: &RsddBddPtr) -> ocaml::Pointer { RsddBddPtr(bdd.0.high()).into() } +// real semiring + +#[ocaml::sig] +pub struct RsddWmcParamsR(WmcParams); +ocaml::custom!(RsddWmcParamsR); + +#[ocaml::func] +#[ocaml::sig("rsdd_bdd_ptr -> rsdd_wmc_params_r -> float")] +pub fn bdd_wmc(bdd: &RsddBddPtr, wmc: &RsddWmcParamsR) -> f64 { + DDNNFPtr::unsmoothed_wmc(&bdd.0, &wmc.0).0 +} + +#[ocaml::func] +#[ocaml::sig("(float * float) list -> rsdd_wmc_params_r")] +pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer { + RsddWmcParamsR(WmcParams::new(HashMap::from_iter( + weights + .into_linked_list() + .iter() + .enumerate() + .map(|(index, (a, b))| { + ( + VarLabel::new(index as u64), + (RealSemiring(*a), RealSemiring(*b)), + ) + }), + ))) + .into() +} + // branch & bound, expected semiring items #[ocaml::sig] pub struct RsddExpectedUtility(ExpectedUtility); diff --git a/src/rsdd.ml b/src/rsdd.ml index fb2d171..f9c0b04 100644 --- a/src/rsdd.ml +++ b/src/rsdd.ml @@ -9,10 +9,11 @@ type rsdd_bdd_builder type rsdd_cnf type rsdd_partial_model type rsdd_var_label +type rsdd_wmc_params_r type rsdd_expected_utility type rsdd_wmc_params_e_u external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order" -external bdd_new_var: rsdd_bdd_builder -> bool -> rsdd_bdd_ptr = "bdd_new_var" +external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var" external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite" external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and" external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or" @@ -26,6 +27,8 @@ external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar" external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low" external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high" +external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc" +external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r" external bdd_bb: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_bb" external bdd_meu: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu" external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu" diff --git a/src/rsdd.mli b/src/rsdd.mli index fb2d171..f9c0b04 100644 --- a/src/rsdd.mli +++ b/src/rsdd.mli @@ -9,10 +9,11 @@ type rsdd_bdd_builder type rsdd_cnf type rsdd_partial_model type rsdd_var_label +type rsdd_wmc_params_r type rsdd_expected_utility type rsdd_wmc_params_e_u external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order" -external bdd_new_var: rsdd_bdd_builder -> bool -> rsdd_bdd_ptr = "bdd_new_var" +external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var" external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite" external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and" external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or" @@ -26,6 +27,8 @@ external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar" external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low" external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high" +external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc" +external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r" external bdd_bb: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_bb" external bdd_meu: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu" external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu"