Skip to content

Commit

Permalink
Replace u64 with i64 in client-facing code, force conversions (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattxwang authored Oct 5, 2023
1 parent 9c02124 commit 2627bea
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ unsafe impl ocaml::ToValue for RsddVarLabel {
unsafe impl ocaml::FromValue for RsddVarLabel {
fn from_value(v: ocaml::Value) -> Self {
let i = unsafe { v.int64_val() };
RsddVarLabel(VarLabel::new(i as u64))
RsddVarLabel(VarLabel::new(i.try_into().unwrap()))
}
}

// disc/dice interface

#[ocaml::func]
#[ocaml::sig("int64 -> rsdd_bdd_builder")]
pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer<RsddBddBuilder> {
pub fn mk_bdd_builder_default_order(num_vars: i64) -> ocaml::Pointer<RsddBddBuilder> {
RsddBddBuilder(RobddBuilder::<AllIteTable<BddPtr>>::new(
VarOrder::linear_order(num_vars as usize),
))
Expand All @@ -56,9 +56,9 @@ pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer<RsddBddBuil
pub fn bdd_new_var(
builder: &'static RsddBddBuilder,
polarity: bool,
) -> (u64, ocaml::Pointer<RsddBddPtr>) {
) -> (i64, ocaml::Pointer<RsddBddPtr>) {
let (lbl, ptr) = builder.0.new_var(polarity);
(lbl.value(), RsddBddPtr(ptr).into())
(lbl.value().try_into().unwrap(), RsddBddPtr(ptr).into())
}

#[ocaml::func]
Expand Down Expand Up @@ -139,10 +139,10 @@ pub fn bdd_eq(builder: &'static RsddBddBuilder, a: &RsddBddPtr, b: &RsddBddPtr)

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> int64")]
pub fn bdd_topvar(bdd: &RsddBddPtr) -> u64 {
pub fn bdd_topvar(bdd: &RsddBddPtr) -> i64 {
match (bdd.0).var_safe() {
Some(x) => x.value(),
None => 0, // TODO: provide a better version for this, maybe a Maybe/Option?
Some(x) => x.value().try_into().unwrap(),
None => -1, // TODO: provide a better version for this, maybe a Maybe/Option?
}
}

Expand Down Expand Up @@ -180,7 +180,7 @@ pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer<Rsdd
.enumerate()
.map(|(index, (a, b))| {
(
VarLabel::new(index as u64),
VarLabel::new(index.try_into().unwrap()),
(RealSemiring(*a), RealSemiring(*b)),
)
}),
Expand All @@ -202,7 +202,7 @@ ocaml::custom!(RsddWmcParamsEU);
pub fn bdd_bb(
bdd: &'static RsddBddPtr,
join_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
num_vars: i64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
Expand All @@ -225,7 +225,7 @@ pub fn bdd_bb(
pub fn bdd_meu(
bdd: &'static RsddBddPtr,
decision_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
num_vars: i64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
Expand Down Expand Up @@ -255,7 +255,7 @@ pub fn new_wmc_params_eu(
.enumerate()
.map(|(index, (a, b))| {
(
VarLabel::new(index as u64),
VarLabel::new(index.try_into().unwrap()),
(ExpectedUtility(a.0, a.1), ExpectedUtility(b.0, b.1)),
)
}),
Expand All @@ -282,15 +282,15 @@ pub fn bdd_builder_compile_cnf(

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_bdd_ptr -> int64")]
pub fn bdd_model_count(builder: &'static RsddBddBuilder, bdd: &'static RsddBddPtr) -> u64 {
pub fn bdd_model_count(builder: &'static RsddBddBuilder, bdd: &'static RsddBddPtr) -> i64 {
let num_vars = builder.0.num_vars();
let smoothed = builder.0.smooth(bdd.0, num_vars);
let unweighted_params: WmcParams<FiniteField<{ primes::U64_LARGEST }>> =
WmcParams::new(HashMap::from_iter(
(0..num_vars as u64)
(0..num_vars.try_into().unwrap())
.map(|v| (VarLabel::new(v), (FiniteField::one(), FiniteField::one()))),
));

let mc = smoothed.unsmoothed_wmc(&unweighted_params).value();
mc as u64
mc.try_into().unwrap()
}

0 comments on commit 2627bea

Please sign in to comment.