diff --git a/egg-pre-dcp/src/extract.rs b/egg-pre-dcp/src/extract.rs index c586b6ea..1ac2e200 100644 --- a/egg-pre-dcp/src/extract.rs +++ b/egg-pre-dcp/src/extract.rs @@ -294,14 +294,28 @@ pub fn get_steps_from_string_maybe_node_limit( // Iterations data. let iterations = runner.iterations; - let num_of_iterations = iterations.len(); + let num_of_iterations = iterations.len() - 1; + println!("Number of iterations: {:?}.", num_of_iterations); + let mut num_rules_applied = 0; + let mut num_iter = 0; for iteration in iterations { - for (_, count) in iteration.applied.iter() { + let mut max_count = 0; + let mut max_rule = Symbol::from(""); + for (name, count) in iteration.applied.iter() { num_rules_applied += count; + if count.clone() > max_count { + max_count = count.clone(); + max_rule = name.clone(); + } } + println!("--- Iteration {:?} data (cumulative) ---", num_iter); + println!("Rewrites applied: {:?}", num_rules_applied); + println!("E-nodes at start: {:?}", iteration.egraph_nodes); + println!("Max rule applied: {:?} (count: {:?}).", max_rule, max_count); + num_iter += 1; } - println!("Number of iterations: {:?}.", num_of_iterations); + println!("---"); println!("Number of rules applied: {:?}.", num_rules_applied); } } else { diff --git a/egg-pre-dcp/src/optimization.rs b/egg-pre-dcp/src/optimization.rs index e717652f..325fa11d 100644 --- a/egg-pre-dcp/src/optimization.rs +++ b/egg-pre-dcp/src/optimization.rs @@ -54,9 +54,15 @@ pub struct Meta { pub domains : HashMap, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TermType { + Problem, Set, Function, ObjFun, Constrs, Constr, Symbol +} + #[cfg(not(stop_on_success))] #[derive(Debug, Clone)] pub struct Data { + pub term_type: TermType, pub domain: Option, pub is_constant: bool, } @@ -66,6 +72,12 @@ impl Analysis for Meta { type Data = Data; fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + // Term type should never be different. + let from_term_type_diff = to.term_type != from.term_type; + if from_term_type_diff { + panic!("Unexpected term type difference."); + } + let d_before_o = to.domain.clone(); match (to.domain.clone(), from.domain.clone()) { (None, Some(_)) => { to.domain = from.domain.clone(); } @@ -113,11 +125,30 @@ impl Analysis for Meta { let domains_map = egraph.analysis.domains.clone(); - + + let mut term_type = TermType::Function; let mut domain = None; let mut is_constant = false; match enode { + Optimization::Prob(_) => { + term_type = TermType::Problem; + } + Optimization::ObjFun(_) => { + term_type = TermType::ObjFun; + } + Optimization::Constr(_) => { + term_type = TermType::Constr; + } + Optimization::Constrs(_) => { + term_type = TermType::Constrs; + } + Optimization::Eq(_) => { + term_type = TermType::Set; + } + Optimization::Le(_) => { + term_type = TermType::Set; + } Optimization::Neg(a) => { domain = domain::option_neg(get_domain(a)); is_constant = get_is_constant(a); @@ -225,21 +256,23 @@ impl Analysis for Meta { domain = Some(Domain::make_singleton(std::f64::consts::PI)); is_constant = true; } - Optimization::Symbol(_) => {} + Optimization::Symbol(_) => { + term_type = TermType::Symbol; + } Optimization::Constant(f) => { domain = Some(Domain::make_singleton((*f).into_inner())); is_constant = true; } - _ => {} } - Data { domain, is_constant } + Data { term_type, domain, is_constant } } } #[cfg(stop_on_success)] #[derive(Debug, Clone)] pub struct DataWithCost { + pub term_type: TermType, pub domain: Option, pub is_constant: bool, pub curvature: Curvature, @@ -253,6 +286,12 @@ impl Analysis for Meta { type Data = DataWithCost; fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + // Term type should never be different. + let from_term_type_diff = to.term_type != from.term_type; + if from_term_type_diff { + panic!("Unexpected term type difference."); + } + let d_before_o = to.domain.clone(); match (to.domain.clone(), from.domain.clone()) { (None, Some(_)) => { to.domain = from.domain.clone(); } @@ -337,6 +376,7 @@ impl Analysis for Meta { let domains_map = egraph.analysis.domains.clone(); + let mut term_type = TermType::Function; let mut domain = None; let mut is_constant = false; let mut curvature = Curvature::Unknown; @@ -346,6 +386,7 @@ impl Analysis for Meta { match enode { Optimization::Prob([a, b]) => { + term_type = TermType::Problem; curvature = if get_curvature(a) >= get_curvature(b) { get_curvature(a) @@ -360,6 +401,7 @@ impl Analysis for Meta { term_size = 1 + get_term_size(a) + get_term_size(b); } Optimization::ObjFun(a) => { + term_type = TermType::ObjFun; // It cannot be concave, because of mapping functions. curvature = if get_curvature(a) <= Curvature::Convex { @@ -372,6 +414,7 @@ impl Analysis for Meta { term_size = 1 + get_term_size(a); } Optimization::Constr([h, c]) => { + term_type = TermType::Constr; // It cannot be concave, because the notion of concavity at the Prop (or set) level // is not well-defined. curvature = @@ -385,6 +428,7 @@ impl Analysis for Meta { term_size = 1 + get_term_size(c); } Optimization::Constrs(a) => { + term_type = TermType::Constrs; curvature = Curvature::Constant; term_size = 0; num_vars = 0; @@ -400,6 +444,7 @@ impl Analysis for Meta { best = format!("(constrs {})", constrs_s_l.join(" ")).parse().unwrap(); } Optimization::Eq([a, b]) => { + term_type = TermType::Set; if get_curvature(a) <= Curvature::Affine && get_curvature(b) <= Curvature::Affine { curvature = Curvature::Affine } @@ -408,6 +453,7 @@ impl Analysis for Meta { term_size = 1 + get_term_size(a) + get_term_size(b); } Optimization::Le([a, b]) => { + term_type = TermType::Set; curvature = curvature::of_le(get_curvature(a), get_curvature(b)); best = format!("(le {} {})", get_best(a), get_best(b)).parse().unwrap(); num_vars = get_num_vars(a) + get_num_vars(b); @@ -704,7 +750,14 @@ impl Analysis for Meta { } } - DataWithCost { domain, is_constant, curvature, best, num_vars, term_size } + DataWithCost { term_type, domain, is_constant, curvature, best, num_vars, term_size } + } +} + +pub fn is_real_expr(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + return egraph[subst[var]].data.term_type == TermType::Function; } } diff --git a/egg-pre-dcp/src/rules.rs b/egg-pre-dcp/src/rules.rs index 680c0b6f..a0282e8f 100644 --- a/egg-pre-dcp/src/rules.rs +++ b/egg-pre-dcp/src/rules.rs @@ -3,6 +3,7 @@ use egg::{rewrite as rw, *}; use crate::optimization; use optimization::Optimization as Optimization; use optimization::Meta as Meta; +use optimization::is_real_expr as is_real_expr; use optimization::is_gt_zero as is_gt_zero; use optimization::is_ge_zero as is_ge_zero; use optimization::is_le_zero as is_le_zero; @@ -77,7 +78,8 @@ pub fn rules() -> Vec> { vec![ rw!("one_mul"; "(mul 1 ?a)" => "?a"), - rw!("one_mul-rev"; "?a" => "(mul 1 ?a)"), + rw!("one_mul-rev"; "?a" => "(mul 1 ?a)" if is_real_expr("?a")), + rw!("mul_zero"; "(mul 0 ?a)" => "0"), @@ -201,7 +203,8 @@ pub fn rules() -> Vec> { vec![ rw!("pow_half_two"; "(pow (pow ?a 0.5) 2)" => "?a" if is_ge_zero("?a")), - rw!("pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" if is_ge_zero("?a")), + rw!("pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" + if is_real_expr("?a") if is_ge_zero("?a")), rw!("binomial_two"; "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))"),