Skip to content

Commit

Permalink
feat: log max rule applied and term types (#29)
Browse files Browse the repository at this point in the history
* feat: log iteration data

* feat: get max rule applied per iteration

* feat: keep track of term type

* fix: term type difference check

* feat: also take into account term type `Symbol`
  • Loading branch information
ramonfmir committed Apr 11, 2024
1 parent f301364 commit cfd34bb
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 10 deletions.
20 changes: 17 additions & 3 deletions egg-pre-dcp/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,28 @@ pub fn get_steps_from_string_maybe_node_limit(

// Iterations data.
let iterations = runner.iterations;
let num_of_iterations = iterations.len();
let num_of_iterations = iterations.len() - 1;
println!("Number of iterations: {:?}.", num_of_iterations);

let mut num_rules_applied = 0;
let mut num_iter = 0;
for iteration in iterations {
for (_, count) in iteration.applied.iter() {
let mut max_count = 0;
let mut max_rule = Symbol::from("");
for (name, count) in iteration.applied.iter() {
num_rules_applied += count;
if count.clone() > max_count {
max_count = count.clone();
max_rule = name.clone();
}
}
println!("--- Iteration {:?} data (cumulative) ---", num_iter);
println!("Rewrites applied: {:?}", num_rules_applied);
println!("E-nodes at start: {:?}", iteration.egraph_nodes);
println!("Max rule applied: {:?} (count: {:?}).", max_rule, max_count);
num_iter += 1;
}
println!("Number of iterations: {:?}.", num_of_iterations);
println!("---");
println!("Number of rules applied: {:?}.", num_rules_applied);
}
} else {
Expand Down
63 changes: 58 additions & 5 deletions egg-pre-dcp/src/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@ pub struct Meta {
pub domains : HashMap<String, Domain>,
}

/// Syntactic category of a term in the optimization AST.
///
/// The analysis assigns one category per e-class; merging two e-classes
/// with different categories is treated as a bug (the `merge` impl panics).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TermType {
    /// The whole optimization problem node.
    Problem,
    /// Set-level terms, i.e. equations and inequalities (`eq`, `le`).
    Set,
    /// A real-valued expression (the default category).
    Function,
    /// The objective-function wrapper node.
    ObjFun,
    /// The list of constraints.
    Constrs,
    /// A single constraint node.
    Constr,
    /// A bare symbol leaf.
    Symbol,
}

/// Per-e-class analysis data (variant used when the `stop_on_success`
/// cfg flag is NOT set; `DataWithCost` is the counterpart otherwise).
#[cfg(not(stop_on_success))]
#[derive(Debug, Clone)]
pub struct Data {
    /// Syntactic category of the e-class. The `merge` impl panics if two
    /// merged e-classes disagree on this, so it is an invariant, not a hint.
    pub term_type: TermType,
    /// Numeric domain of the term, when one could be derived; `None` when
    /// the domain is unknown.
    pub domain: Option<Domain>,
    /// Whether the term is known to evaluate to a constant.
    pub is_constant: bool,
}
Expand All @@ -66,6 +72,12 @@ impl Analysis<Optimization> for Meta {
type Data = Data;

fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
// Term type should never be different.
let from_term_type_diff = to.term_type != from.term_type;
if from_term_type_diff {
panic!("Unexpected term type difference.");
}

let d_before_o = to.domain.clone();
match (to.domain.clone(), from.domain.clone()) {
(None, Some(_)) => { to.domain = from.domain.clone(); }
Expand Down Expand Up @@ -113,11 +125,30 @@ impl Analysis<Optimization> for Meta {

let domains_map =
egraph.analysis.domains.clone();


let mut term_type = TermType::Function;
let mut domain = None;
let mut is_constant = false;

match enode {
Optimization::Prob(_) => {
term_type = TermType::Problem;
}
Optimization::ObjFun(_) => {
term_type = TermType::ObjFun;
}
Optimization::Constr(_) => {
term_type = TermType::Constr;
}
Optimization::Constrs(_) => {
term_type = TermType::Constrs;
}
Optimization::Eq(_) => {
term_type = TermType::Set;
}
Optimization::Le(_) => {
term_type = TermType::Set;
}
Optimization::Neg(a) => {
domain = domain::option_neg(get_domain(a));
is_constant = get_is_constant(a);
Expand Down Expand Up @@ -225,21 +256,23 @@ impl Analysis<Optimization> for Meta {
domain = Some(Domain::make_singleton(std::f64::consts::PI));
is_constant = true;
}
Optimization::Symbol(_) => {}
Optimization::Symbol(_) => {
term_type = TermType::Symbol;
}
Optimization::Constant(f) => {
domain = Some(Domain::make_singleton((*f).into_inner()));
is_constant = true;
}
_ => {}
}

Data { domain, is_constant }
Data { term_type, domain, is_constant }
}
}

#[cfg(stop_on_success)]
#[derive(Debug, Clone)]
pub struct DataWithCost {
pub term_type: TermType,
pub domain: Option<Domain>,
pub is_constant: bool,
pub curvature: Curvature,
Expand All @@ -253,6 +286,12 @@ impl Analysis<Optimization> for Meta {
type Data = DataWithCost;

fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge {
// Term type should never be different.
let from_term_type_diff = to.term_type != from.term_type;
if from_term_type_diff {
panic!("Unexpected term type difference.");
}

let d_before_o = to.domain.clone();
match (to.domain.clone(), from.domain.clone()) {
(None, Some(_)) => { to.domain = from.domain.clone(); }
Expand Down Expand Up @@ -337,6 +376,7 @@ impl Analysis<Optimization> for Meta {
let domains_map =
egraph.analysis.domains.clone();

let mut term_type = TermType::Function;
let mut domain = None;
let mut is_constant = false;
let mut curvature = Curvature::Unknown;
Expand All @@ -346,6 +386,7 @@ impl Analysis<Optimization> for Meta {

match enode {
Optimization::Prob([a, b]) => {
term_type = TermType::Problem;
curvature =
if get_curvature(a) >= get_curvature(b) {
get_curvature(a)
Expand All @@ -360,6 +401,7 @@ impl Analysis<Optimization> for Meta {
term_size = 1 + get_term_size(a) + get_term_size(b);
}
Optimization::ObjFun(a) => {
term_type = TermType::ObjFun;
// It cannot be concave, because of mapping functions.
curvature =
if get_curvature(a) <= Curvature::Convex {
Expand All @@ -372,6 +414,7 @@ impl Analysis<Optimization> for Meta {
term_size = 1 + get_term_size(a);
}
Optimization::Constr([h, c]) => {
term_type = TermType::Constr;
// It cannot be concave, because the notion of concavity at the Prop (or set) level
// is not well-defined.
curvature =
Expand All @@ -385,6 +428,7 @@ impl Analysis<Optimization> for Meta {
term_size = 1 + get_term_size(c);
}
Optimization::Constrs(a) => {
term_type = TermType::Constrs;
curvature = Curvature::Constant;
term_size = 0;
num_vars = 0;
Expand All @@ -400,6 +444,7 @@ impl Analysis<Optimization> for Meta {
best = format!("(constrs {})", constrs_s_l.join(" ")).parse().unwrap();
}
Optimization::Eq([a, b]) => {
term_type = TermType::Set;
if get_curvature(a) <= Curvature::Affine && get_curvature(b) <= Curvature::Affine {
curvature = Curvature::Affine
}
Expand All @@ -408,6 +453,7 @@ impl Analysis<Optimization> for Meta {
term_size = 1 + get_term_size(a) + get_term_size(b);
}
Optimization::Le([a, b]) => {
term_type = TermType::Set;
curvature = curvature::of_le(get_curvature(a), get_curvature(b));
best = format!("(le {} {})", get_best(a), get_best(b)).parse().unwrap();
num_vars = get_num_vars(a) + get_num_vars(b);
Expand Down Expand Up @@ -704,7 +750,14 @@ impl Analysis<Optimization> for Meta {
}
}

DataWithCost { domain, is_constant, curvature, best, num_vars, term_size }
DataWithCost { term_type, domain, is_constant, curvature, best, num_vars, term_size }
}
}

/// Condition for conditional rewrites: returns a predicate that holds when
/// the e-class bound to the pattern variable `var` is a real-valued
/// expression (`TermType::Function`) — as opposed to a problem, set,
/// constraint, or bare-symbol node. Used to stop rules like `one_mul-rev`
/// from firing on non-expression terms.
///
/// # Panics
///
/// Panics if `var` is not a valid pattern-variable string (e.g. `"?a"`);
/// rule patterns are fixed at compile time, so a failure here is a bug.
pub fn is_real_expr(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
    let var = var.parse().expect("is_real_expr: invalid pattern variable");
    move |egraph, _, subst| egraph[subst[var]].data.term_type == TermType::Function
}

Expand Down
7 changes: 5 additions & 2 deletions egg-pre-dcp/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use egg::{rewrite as rw, *};
use crate::optimization;
use optimization::Optimization as Optimization;
use optimization::Meta as Meta;
use optimization::is_real_expr as is_real_expr;
use optimization::is_gt_zero as is_gt_zero;
use optimization::is_ge_zero as is_ge_zero;
use optimization::is_le_zero as is_le_zero;
Expand Down Expand Up @@ -77,7 +78,8 @@ pub fn rules() -> Vec<Rewrite<Optimization, Meta>> { vec![

rw!("one_mul"; "(mul 1 ?a)" => "?a"),

rw!("one_mul-rev"; "?a" => "(mul 1 ?a)"),
rw!("one_mul-rev"; "?a" => "(mul 1 ?a)" if is_real_expr("?a")),


rw!("mul_zero"; "(mul 0 ?a)" => "0"),

Expand Down Expand Up @@ -201,7 +203,8 @@ pub fn rules() -> Vec<Rewrite<Optimization, Meta>> { vec![

rw!("pow_half_two"; "(pow (pow ?a 0.5) 2)" => "?a" if is_ge_zero("?a")),

rw!("pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" if is_ge_zero("?a")),
rw!("pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)"
if is_real_expr("?a") if is_ge_zero("?a")),

rw!("binomial_two"; "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))"),

Expand Down

0 comments on commit cfd34bb

Please sign in to comment.