Skip to content

Commit

Permalink
Add WMC for RealSemiring, change VarLabel return value (#4)
Browse files Browse the repository at this point in the history
Co-authored-by: Minsung Cho <[email protected]>
  • Loading branch information
mattxwang and minsungc committed Sep 14, 2023
1 parent 19b2181 commit 2698ca3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
43 changes: 38 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rsdd::{
builder::{bdd::RobddBuilder, cache::AllIteTable, BottomUpBuilder},
constants::primes,
repr::{BddPtr, Cnf, DDNNFPtr, PartialModel, VarLabel, VarOrder, WmcParams},
util::semirings::{ExpectedUtility, FiniteField, Semiring},
util::semirings::{ExpectedUtility, FiniteField, RealSemiring, Semiring},
};

#[ocaml::sig]
Expand Down Expand Up @@ -52,10 +52,13 @@ pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer<RsddBddBuil
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> bool -> rsdd_bdd_ptr")]
pub fn bdd_new_var(builder: &'static RsddBddBuilder, polarity: bool) -> ocaml::Pointer<RsddBddPtr> {
let (_, ptr) = builder.0.new_var(polarity);
RsddBddPtr(ptr).into()
#[ocaml::sig("rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr)")]
pub fn bdd_new_var(
builder: &'static RsddBddBuilder,
polarity: bool,
) -> (u64, ocaml::Pointer<RsddBddPtr>) {
let (lbl, ptr) = builder.0.new_var(polarity);
(lbl.value(), RsddBddPtr(ptr).into())
}

#[ocaml::func]
Expand Down Expand Up @@ -155,6 +158,36 @@ pub fn bdd_high(bdd: &RsddBddPtr) -> ocaml::Pointer<RsddBddPtr> {
RsddBddPtr(bdd.0.high()).into()
}

// real semiring

#[ocaml::sig]
pub struct RsddWmcParamsR(WmcParams<RealSemiring>);
ocaml::custom!(RsddWmcParamsR);

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_wmc_params_r -> float")]
pub fn bdd_wmc(bdd: &RsddBddPtr, wmc: &RsddWmcParamsR) -> f64 {
DDNNFPtr::unsmoothed_wmc(&bdd.0, &wmc.0).0
}

#[ocaml::func]
#[ocaml::sig("(float * float) list -> rsdd_wmc_params_r")]
pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer<RsddWmcParamsR> {
RsddWmcParamsR(WmcParams::new(HashMap::from_iter(
weights
.into_linked_list()
.iter()
.enumerate()
.map(|(index, (a, b))| {
(
VarLabel::new(index as u64),
(RealSemiring(*a), RealSemiring(*b)),
)
}),
)))
.into()
}

// branch & bound, expected semiring items
#[ocaml::sig]
pub struct RsddExpectedUtility(ExpectedUtility);
Expand Down
5 changes: 4 additions & 1 deletion src/rsdd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ type rsdd_bdd_builder
type rsdd_cnf
type rsdd_partial_model
type rsdd_var_label
type rsdd_wmc_params_r
type rsdd_expected_utility
type rsdd_wmc_params_e_u
external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order"
external bdd_new_var: rsdd_bdd_builder -> bool -> rsdd_bdd_ptr = "bdd_new_var"
external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var"
external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite"
external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"
external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"
Expand All @@ -26,6 +27,8 @@ external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd
external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"
external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r"
external bdd_bb: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_bb"
external bdd_meu: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu"
external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu"
Expand Down
5 changes: 4 additions & 1 deletion src/rsdd.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ type rsdd_bdd_builder
type rsdd_cnf
type rsdd_partial_model
type rsdd_var_label
type rsdd_wmc_params_r
type rsdd_expected_utility
type rsdd_wmc_params_e_u
external mk_bdd_builder_default_order: int64 -> rsdd_bdd_builder = "mk_bdd_builder_default_order"
external bdd_new_var: rsdd_bdd_builder -> bool -> rsdd_bdd_ptr = "bdd_new_var"
external bdd_new_var: rsdd_bdd_builder -> bool -> (int64 * rsdd_bdd_ptr) = "bdd_new_var"
external bdd_ite: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_ite"
external bdd_and: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_and"
external bdd_or: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_or"
Expand All @@ -26,6 +27,8 @@ external bdd_eq: rsdd_bdd_builder -> rsdd_bdd_ptr -> rsdd_bdd_ptr -> bool = "bdd
external bdd_topvar: rsdd_bdd_ptr -> int64 = "bdd_topvar"
external bdd_low: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_low"
external bdd_high: rsdd_bdd_ptr -> rsdd_bdd_ptr = "bdd_high"
external bdd_wmc: rsdd_bdd_ptr -> rsdd_wmc_params_r -> float = "bdd_wmc"
external new_wmc_params_r: (float * float) list -> rsdd_wmc_params_r = "new_wmc_params_r"
external bdd_bb: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_bb"
external bdd_meu: rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model = "bdd_meu"
external new_wmc_params_eu: ((float * float) * (float * float)) list -> rsdd_wmc_params_e_u = "new_wmc_params_eu"
Expand Down

0 comments on commit 2698ca3

Please sign in to comment.