diff --git a/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean b/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean index 4275a8ea..d2a7b1e7 100644 --- a/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean +++ b/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean @@ -54,6 +54,7 @@ def EggOCTreeExtended := EggOCTree × Array EggDomainIdentified /-- A request consists of an `EggMinimization` and a list of domains per variable (or parameter). -/ structure EggRequest where + probName : String domains : List (String × EggDomain) target : EggMinimization diff --git a/CvxLean/Tactic/PreDCP/Egg/Runner.lean b/CvxLean/Tactic/PreDCP/Egg/Runner.lean index 4f286479..5acdb77a 100644 --- a/CvxLean/Tactic/PreDCP/Egg/Runner.lean +++ b/CvxLean/Tactic/PreDCP/Egg/Runner.lean @@ -67,6 +67,7 @@ NOTE: Tuples are lists of two elements. -/ def EggRequest.toJson (e : EggRequest) : String := "{" ++ surroundQuotes "request" ++ " : " ++ surroundQuotes "PerformRewrite" ++ ", " ++ + surroundQuotes "prob_name" ++ " : " ++ surroundQuotes e.probName ++ ", " ++ surroundQuotes "domains" ++ " : " ++ "[" ++ (", ".intercalate <| e.domains.map (fun domain => @@ -118,7 +119,8 @@ def runEggRequestRaw (requestJson : String) : MetaM String := do return stdout /-- Read `egg`'s output and trun it into an array of `EggRewrite`s. -/ -def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do +def parseEggResponse (responseString : String) : + MetaM (HashMap String (Array EggRewrite)) := do dbg_trace s!"Egg response: {responseString}" let outJson : Json ← match Json.parse responseString with | Except.error e => throwError (s!"error calling `egg`, JSON parsing error ({e}).") @@ -133,29 +135,33 @@ def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do else let steps ← liftExcept <| outJson.getObjVal? "steps" - let steps ← liftExcept <| Json.getArr? steps - - let res := steps.map fun step => - let rewriteName := (step.getObjValD "rewrite_name").getStr! - let direction := match (step.getObjValD "direction").getStr! with - | "Forward" => EggRewriteDirection.Forward - | "Backward" => EggRewriteDirection.Backward - | _ => panic! "Unexpected rewrite direction." - let location := (step.getObjValD "location").getStr! - let subexprFrom := (step.getObjValD "subexpr_from").getStr! - let subexprTo := (step.getObjValD "subexpr_to").getStr! - let expectedTerm := (step.getObjValD "expected_term").getStr! - { rewriteName := rewriteName, - direction := direction, - location := location, - subexprFrom := subexprFrom, - subexprTo := subexprTo, - expectedTerm := expectedTerm } + let steps ← liftExcept <| Json.getObj? steps + + let mut res := HashMap.empty + for ⟨componentName, componentSteps⟩ in steps.toArray do + let componentSteps ← liftExcept <| Json.getArr? componentSteps + let componentStepsParsed : Array EggRewrite := componentSteps.map fun step => + let rewriteName := (step.getObjValD "rewrite_name").getStr! + let direction := match (step.getObjValD "direction").getStr! with + | "Forward" => EggRewriteDirection.Forward + | "Backward" => EggRewriteDirection.Backward + | _ => panic! "Unexpected rewrite direction." + let location := (step.getObjValD "location").getStr! + let subexprFrom := (step.getObjValD "subexpr_from").getStr! + let subexprTo := (step.getObjValD "subexpr_to").getStr! + let expectedTerm := (step.getObjValD "expected_term").getStr! + { rewriteName := rewriteName, + direction := direction, + location := location, + subexprFrom := subexprFrom, + subexprTo := subexprTo, + expectedTerm := expectedTerm } + res := res.insert componentName componentStepsParsed return res /-- Run request to `egg` and parse the output to get an array of rewrites, if successful. -/ -def runEggRequest (request : EggRequest) : MetaM (Array EggRewrite) := +def runEggRequest (request : EggRequest) : MetaM (HashMap String (Array EggRewrite)) := dbg_trace s!"Running egg request: {request.toJson}" runEggRequestRaw request.toJson >>= parseEggResponse diff --git a/CvxLean/Tactic/PreDCP/PreDCP.lean b/CvxLean/Tactic/PreDCP/PreDCP.lean index 713abdcf..1c258889 100644 --- a/CvxLean/Tactic/PreDCP/PreDCP.lean +++ b/CvxLean/Tactic/PreDCP/PreDCP.lean @@ -207,30 +207,37 @@ def preDCPBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do -- Prepare `egg` request. let eggMinimization := EggMinimization.ofEggOCTree gStr let eggRequest : EggRequest := - { domains := (varDomainConstrs ++ paramDomains).data, + { probName := "lean_prob", + domains := (varDomainConstrs ++ paramDomains).data, target := eggMinimization } try -- Call `egg` (time it for evaluation). let bef ← BaseIO.toIO IO.monoMsNow - let steps ← runEggRequest eggRequest + let stepsByComponent ← runEggRequest eggRequest let aft ← BaseIO.toIO IO.monoMsNow let diff := aft - bef dbg_trace s!"Egg time: {diff} ms." - dbg_trace s!"Number of steps: {steps.size}." dbg_trace s!"Term size: {probSize}." dbg_trace s!"Term JSON: {eggMinimization.toJson}." -- Apply steps. + let mut stepsCount := 0 let mut g := g - for step in steps do - let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic - if gs.length != 1 then - trace[CvxLean.debug] "Remaining goals: {gs}." - throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)." - else - trace[CvxLean.debug] "Rewrote {step.rewriteName}." - g := gs[0]! + for (_componentName, steps) in stepsByComponent do + -- TODO: Since rewrites are now split by component, we could apply the corresponding + -- congruence rule just once. + for step in steps do + stepsCount := stepsCount + 1 + let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic + if gs.length != 1 then + trace[CvxLean.debug] "Remaining goals: {gs}." + throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)." + else + trace[CvxLean.debug] "Rewrote {step.rewriteName}." + g := gs[0]! + + dbg_trace s!"Number of steps: {stepsCount}." let gsFinal ← evalTacticAt (← `(tactic| equivalence_rfl)) g if gsFinal.length != 0 then diff --git a/egg-pre-dcp/src/curvature.rs b/egg-pre-dcp/src/curvature.rs index 617949f8..c2dd456f 100644 --- a/egg-pre-dcp/src/curvature.rs +++ b/egg-pre-dcp/src/curvature.rs @@ -1,5 +1,6 @@ use core::cmp::Ordering; use std::fmt; +use serde::Serialize; use crate::domain; use domain::Domain as Domain; @@ -78,6 +79,12 @@ impl fmt::Display for Curvature { } } +impl Serialize for Curvature { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + serializer.serialize_str(&self.to_string()) + } +} + pub fn join(c1: Curvature, c2: Curvature) -> Curvature { if c1 <= c2 { c2 } else if c2 <= c1 { c1 } else { Unknown } } diff --git a/egg-pre-dcp/src/explain_util.rs b/egg-pre-dcp/src/explain_util.rs index 0517d05a..60cf4cfe 100644 --- a/egg-pre-dcp/src/explain_util.rs +++ b/egg-pre-dcp/src/explain_util.rs @@ -1,6 +1,6 @@ use egg::{*}; use std::collections::HashMap; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::optimization; use optimization::Optimization as Optimization; @@ -8,6 +8,86 @@ use optimization::Meta as Meta; pub type Rewrite = egg::Rewrite; +// Representation of an optimization problem by its components. + +#[derive(Deserialize, Debug)] +pub struct Minimization { + pub obj_fun : String, + pub constrs : Vec<(String, String)>, +} + +impl ToString for Minimization { + fn to_string(&self) -> String { + let obj_fun_s: String = format!("(objFun {})", self.obj_fun); + let constrs_s_l : Vec = + self.constrs.iter().map( + |(h, c)| format!("(constr {} {})", h, c)).collect(); + let constr_s = format!("(constrs {})", constrs_s_l.join(" ")); + return format!("(prob {} {})", obj_fun_s, constr_s); + } +} + +impl Serialize for Minimization { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + serializer.serialize_str(&self.to_string()) + } +} + +pub struct MinimizationIter { + min_iter : Vec<(String, String)>, +} + +impl Minimization { + pub fn iter(&self) -> MinimizationIter { + let obj_fun_s = format!("(objFun {})", self.obj_fun); + let mut min_iter = vec![("objFun".to_string(), obj_fun_s.clone())]; + min_iter.append(&mut self.constrs.clone()); + MinimizationIter { min_iter } + } +} + +impl Iterator for MinimizationIter { + type Item = (String, String); + + fn next(&mut self) -> Option { + if self.min_iter.is_empty() { + None + } else { + Some(self.min_iter.remove(0)) + } + } +} + +// Since sometimes we want to rewrite a full optimization problem, and sometimes just an expression. +#[derive(Deserialize, Debug)] +pub enum MinimizationOrExpr { + Min(Minimization), + Expr(String), +} + +impl Serialize for MinimizationOrExpr { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + match self { + MinimizationOrExpr::Min(min) => min.serialize(serializer), + MinimizationOrExpr::Expr(expr) => expr.serialize(serializer), + } + } +} + +impl MinimizationOrExpr { + pub fn iter(&self) -> MinimizationIter { + match self { + MinimizationOrExpr::Min(min) => min.iter(), + MinimizationOrExpr::Expr(expr) => { + let min_iter = vec![("expr".to_string(), expr.clone())]; + MinimizationIter { min_iter } + } + } + } +} + +// Reading flat terms. + #[derive(Clone, Copy, Serialize, Debug)] pub enum Direction { Forward, diff --git a/egg-pre-dcp/src/extract.rs b/egg-pre-dcp/src/extract.rs index 1ac2e200..0014adbd 100644 --- a/egg-pre-dcp/src/extract.rs +++ b/egg-pre-dcp/src/extract.rs @@ -5,7 +5,7 @@ use std::time::Duration; use std::collections::HashMap; use std::vec; use std::time::Instant; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use crate::domain; use domain::Domain as Domain; @@ -26,10 +26,19 @@ use crate::cost; use cost::DCPCost as DCPCost; use crate::explain_util; +use explain_util::Minimization as Minimization; +use explain_util::MinimizationOrExpr as MinimizationOrExpr; use explain_util::Direction as Direction; use explain_util::get_rewrite_name_and_direction as get_rewrite_name_and_direction; use explain_util::expected_expr_with_hole as expected_expr_with_hole; +use crate::report; +use report::IterationReport as IterationReport; +use report::ComponentReport as ComponentReport; +use report::Report as Report; + +// A step in the explanation, and methods from extracting them from pairs of "flat terms". + #[derive(Serialize, Debug)] pub struct Step { rewrite_name : String, @@ -132,44 +141,38 @@ fn get_step(current: &FlatTerm, next: &FlatTerm) -> return None; } -#[derive(Deserialize, Debug)] -pub struct Minimization { - pub obj_fun : String, - pub constrs : Vec<(String, String)>, -} +// Return the rewrite steps if egg successfully found a chain of rewrites to transform the term into +// DCP form. Return `None` if it didn't. -impl ToString for Minimization { - fn to_string(&self) -> String { - let obj_fun_s: String = format!("(objFun {})", self.obj_fun); - let constrs_s_l : Vec = - self.constrs.iter().map(|(h, c)| format!("(constr {} {})", h, c)).collect(); - let constr_s = format!("(constrs {})", constrs_s_l.join(" ")); - return format!("(prob {} {})", obj_fun_s, constr_s); - } -} - -// Return the rewrite steps if egg successfully found a chain of rewrites to -// transform the term into DCP form. Return `None` if it didn't. #[allow(unused)] -pub fn get_steps_maybe_node_limit(prob: Minimization, domains_vec: Vec<(String, Domain)>, debug: bool, node_limit: Option) -> Option> { - get_steps_from_string_maybe_node_limit(&prob.to_string(), domains_vec, debug, node_limit) +pub fn get_steps_maybe_node_limit( + prob_name : &str, + prob: Minimization, + domains_vec: Vec<(String, Domain)>, + debug: bool, + node_limit: Option) -> Option>> { + get_steps_from_string_maybe_node_limit( + prob_name, MinimizationOrExpr::Min(prob), domains_vec, debug, node_limit) } -pub fn get_steps(prob: Minimization, domains_vec: Vec<(String, Domain)>, debug: bool) -> Option> { - get_steps_from_string(&prob.to_string(), domains_vec, debug) -} - -pub fn get_steps_from_string(prob_s: &str, domains_vec: Vec<(String, Domain)>, debug: bool) -> Option> { - get_steps_from_string_maybe_node_limit(prob_s, domains_vec, debug, None) +pub fn get_steps( + prob_name : &str, + prob: Minimization, + domains_vec: Vec<(String, Domain)>, + debug: bool) -> Option>> { + get_steps_from_string_maybe_node_limit( + prob_name, MinimizationOrExpr::Min(prob), domains_vec, debug, None) } pub fn get_steps_from_string_maybe_node_limit( - prob_s: &str, - domains_vec: Vec<(String, Domain)>, - debug: bool, - node_limit: Option) -> Option> { - let starting_time: Instant = Instant::now(); - let expr: RecExpr = prob_s.parse().unwrap(); + prob_name : &str, + prob: MinimizationOrExpr, + domains_vec: Vec<(String, Domain)>, + debug: bool, + node_limit: Option) -> Option>> { + let starting_time = Instant::now(); + let mut report = Report::new(prob_name.to_string()); + let mut res = HashMap::new(); // Process domains, intersecting domains assigned to the same variable. let domains_len = domains_vec.len(); @@ -192,114 +195,122 @@ pub fn get_steps_from_string_maybe_node_limit( return None; } } - - // Choose the specified node limit, or select the default ones (for `stop_on_success`, we set - // a large limit; note that if a DCP term is found it will not be reached). - let node_limits = - if let Some(n) = node_limit { - vec![n] - } else { - if cfg!(stop_on_success) { - vec![100000] - } else { - vec![2500, 5000, 10000, 20000, 40000, 80000] - } - }; - for node_limit in node_limits { - let analysis = Meta { - domains : domains.clone() - }; - - // Set up the runner with the given expression, analysis and limits. - let iter_limit = node_limit / 250; - let time_limit = (node_limit / 500).try_into().unwrap(); - let mut runner: Runner = - Runner::new(analysis) - .with_explanations_enabled() - .with_explanation_length_optimization() - .with_node_limit(node_limit) - .with_iter_limit(iter_limit) - .with_time_limit(Duration::from_secs(time_limit)) - .with_expr(&expr); - - #[cfg(stop_on_success)] - { - runner = - runner - .with_hook(|runner| { - let data = runner.egraph[runner.roots[0]].data.clone(); - if data.curvature <= Curvature::Convex { - return Err("DCP term found.".to_string()); - } - return Ok(()); - }) - .run(&rules()) - } - #[cfg(not(stop_on_success))] - { - runner = runner.run(&rules()) - } - - if debug { - println!("Creating graph with {:?} nodes.", runner.egraph.total_number_of_nodes()); - let dot_str = runner.egraph.dot().to_string(); - fs::write("test.dot", dot_str).expect(""); - } - - let root = runner.roots[0]; - - // Extract the best term and best cost. This is obtained directly from the e-class - // analysis in the `stop_on_success` case, and by running the extractor otherwise. - let best_curvature; - let best_num_vars; - let best_term_size; - let best; - #[cfg(stop_on_success)] - { - let result_data = runner.egraph[root].data.clone(); - best = result_data.best; - best_curvature = result_data.curvature; - best_num_vars = result_data.num_vars; - best_term_size = result_data.term_size; - } - #[cfg(not(stop_on_success))] - { - let cost_func = DCPCost { egraph: &runner.egraph }; - let extractor = Extractor::new(&runner.egraph, cost_func); - let (best_cost_found, best_found) = extractor.find_best(root); - best = best_found; - best_curvature = best_cost_found.curvature; - best_num_vars = best_cost_found.num_vars; - best_term_size = best_cost_found.term_size; - } + // NOTE: each domain constraint is an expression with 3 nodes, e.g. `0 <= x`. + report.set_initial_term_size(3 * (domains_len as u32)); - let curvature = best_curvature; - let num_vars = best_num_vars; + // Go through the components. + // TODO: Opportunity to parallelize this loop (e.g., using rayon). + for (component_name, component_s) in prob.iter() { + let component_starting_time = Instant::now(); - // Note: each domain constraint is an expression with 3 nodes, e.g. `0 <= x`. - let term_size = best_term_size + 3 * (domains_len as u32); + // Expression to rewrite. + let expr: RecExpr = component_s.parse().unwrap(); - if curvature <= Curvature::Convex { + // Choose the specified node limit, or select the default ones (for `stop_on_success`, we + // set a large limit; note that if a DCP term is found it will not be reached). + let node_limits = + if let Some(n) = node_limit { + vec![n] + } else { + if cfg!(stop_on_success) { + vec![100000] + } else { + vec![2500, 5000, 10000, 20000, 40000, 80000] + } + }; + let mut success = false; + + for node_limit in node_limits { + let mut component_report = ComponentReport::new(component_name.clone()); + + let analysis = Meta { + domains : domains.clone() + }; + + // Set up the runner with the given expression, analysis and limits. + let iter_limit = node_limit / 250; + let time_limit = (node_limit / 500).try_into().unwrap(); + let mut runner: Runner = + Runner::new(analysis) + .with_explanations_enabled() + .with_explanation_length_optimization() + .with_node_limit(node_limit) + .with_iter_limit(iter_limit) + .with_time_limit(Duration::from_secs(time_limit)) + .with_expr(&expr); + + #[cfg(stop_on_success)] + { + runner = + runner + .with_hook(|runner| { + let data = runner.egraph[runner.roots[0]].data.clone(); + if data.curvature <= Curvature::Convex { + return Err("DCP term found.".to_string()); + } + return Ok(()); + }) + .run(&rules()) + } + #[cfg(not(stop_on_success))] + { + runner = runner.run(&rules()) + } + if debug { - let total_nodes = runner.egraph.total_number_of_nodes(); - println!("Succeeded with node limit {:?} (using {:?} nodes).", node_limit, total_nodes); - println!("Best curvature: {:?}.", curvature); - println!("Best number of variables: {:?}.", num_vars); - println!("Best term size: {:?}.", term_size); - println!("Best term: {:?}.", best.to_string()); + println!("Creating graph with {:?} nodes.", runner.egraph.total_number_of_nodes()); + let dot_str = runner.egraph.dot().to_string(); + fs::write("test.dot", dot_str).expect(""); + } + + let root = runner.roots[0]; + // Extract the best term and best cost. This is obtained directly from the e-class + // analysis in the `stop_on_success` case, and by running the extractor otherwise. + let best_curvature; + let best_num_vars; + let best_term_size; + let best; + #[cfg(stop_on_success)] + { + let result_data = runner.egraph[root].data.clone(); + best = result_data.best; + best_curvature = result_data.curvature; + best_num_vars = result_data.num_vars; + best_term_size = result_data.term_size; + } + #[cfg(not(stop_on_success))] + { + let cost_func = DCPCost { egraph: &runner.egraph }; + let extractor = Extractor::new(&runner.egraph, cost_func); + let (best_cost_found, best_found) = extractor.find_best(root); + best = best_found; + best_curvature = best_cost_found.curvature; + best_num_vars = best_cost_found.num_vars; + best_term_size = best_cost_found.term_size; + } + + if best_curvature <= Curvature::Convex { + let nodes = runner.egraph.total_number_of_nodes(); let build_time = starting_time.elapsed().as_millis(); - println!("E-graph building time: {:.2?} ms.", build_time); + + component_report.set_nodes(nodes); + component_report.set_node_limit(node_limit); + component_report.set_best_curvature(best_curvature); + component_report.set_best_num_vars(best_num_vars); + component_report.set_best_term_size(best_term_size); + component_report.set_best_term(best.to_string()); + component_report.set_build_time(build_time); // Iterations data. let iterations = runner.iterations; let num_of_iterations = iterations.len() - 1; - println!("Number of iterations: {:?}.", num_of_iterations); + component_report.set_num_iterations(num_of_iterations); let mut num_rules_applied = 0; - let mut num_iter = 0; - for iteration in iterations { + for i in 0..num_of_iterations { + let iteration = &iterations[i]; let mut max_count = 0; let mut max_rule = Symbol::from(""); for (name, count) in iteration.applied.iter() { @@ -309,54 +320,73 @@ pub fn get_steps_from_string_maybe_node_limit( max_rule = name.clone(); } } - println!("--- Iteration {:?} data (cumulative) ---", num_iter); - println!("Rewrites applied: {:?}", num_rules_applied); - println!("E-nodes at start: {:?}", iteration.egraph_nodes); - println!("Max rule applied: {:?} (count: {:?}).", max_rule, max_count); - num_iter += 1; + + let iter_report = IterationReport::new( + num_rules_applied, + iteration.egraph_nodes, + max_rule.to_string(), + max_count, + ); + component_report.add_iteration_report(iter_report); } - println!("---"); - println!("Number of rules applied: {:?}.", num_rules_applied); + component_report.set_num_rules_applied(num_rules_applied); + } else { + // If term is not DCP, try with the next node limit. + continue; } - } else { - // If term is not DCP, try with the next node limit. - continue; - } - let after_build_time = Instant::now(); - - // If term is DCP, find the explanation. - let mut egraph = runner.egraph; - let mut explanation : Explanation = - egraph.explain_equivalence(&expr, &best); - let flat_explanation : &FlatExplanation = - explanation.make_flat_explanation(); - if debug { - println!("Number of steps: {}.", flat_explanation.len() - 1); - } - - let mut res = Vec::new(); - for i in 0..flat_explanation.len() - 1 { - let current = &flat_explanation[i]; - let next = &flat_explanation[i + 1]; - match get_step(current, next) { - Some(step) => { res.push(step); } - None => { - // Should not get here. - println!("Failed to extract step."); + let after_build_time = Instant::now(); + + // If term is DCP, find the explanation. + let mut egraph = runner.egraph; + let mut explanation : Explanation = + egraph.explain_equivalence(&expr, &best); + let flat_explanation : &FlatExplanation = + explanation.make_flat_explanation(); + component_report.set_steps_count(flat_explanation.len() - 1); + + let mut steps = Vec::new(); + for i in 0..flat_explanation.len() - 1 { + let current = &flat_explanation[i]; + let next = &flat_explanation[i + 1]; + match get_step(current, next) { + Some(mut step) => { + // TODO: `get_step_aux` can be simplified, since we know the location + // directly. + step.location = component_name.clone(); + steps.push(step); + } + None => { + panic!("Failed to extract step."); + } } } + + let explain_time = after_build_time.elapsed().as_millis(); + let component_time = component_starting_time.elapsed().as_millis(); + component_report.set_explain_time(explain_time); + component_report.set_component_time(component_time); + + report.add_component_report(component_report); + + res.insert(component_name.clone(), steps); + + success = true; + break; } - if debug { - let extract_time = after_build_time.elapsed().as_millis(); - let total_time = starting_time.elapsed().as_millis(); - println!("Step extraction time: {:.2?} ms.", extract_time); - println!("Total time: {:.2?} ms.", total_time); + // It failed for all node limits. + if !success { + return None; } + } + + // Craft final report. + let total_time = starting_time.elapsed().as_millis(); + report.set_total_time(total_time); - return Some(res); + if debug { + println!("{:?}", report); } - // It failed for all node limits. - return None; + return Some(res); } diff --git a/egg-pre-dcp/src/lib.rs b/egg-pre-dcp/src/lib.rs index cc060f82..584dfff3 100644 --- a/egg-pre-dcp/src/lib.rs +++ b/egg-pre-dcp/src/lib.rs @@ -11,6 +11,8 @@ pub mod cost; pub mod explain_util; +pub mod report; + pub mod extract; pub mod test_util; diff --git a/egg-pre-dcp/src/main.rs b/egg-pre-dcp/src/main.rs index 7638fb77..9bb0a873 100644 --- a/egg-pre-dcp/src/main.rs +++ b/egg-pre-dcp/src/main.rs @@ -1,4 +1,5 @@ use std::io; +use std::collections::HashMap; use serde::{Deserialize, Serialize}; #[allow(dead_code)] @@ -14,9 +15,11 @@ mod rules; mod cost; mod explain_util; +use explain_util::Minimization as Minimization; + +mod report; mod extract; -use extract::Minimization as Minimization; use extract::Step as Step; use extract::get_steps as get_steps; @@ -25,7 +28,8 @@ use extract::get_steps as get_steps; #[derive(Deserialize, Debug)] #[serde(tag = "request")] enum Request { - PerformRewrite { + PerformRewrite { + prob_name : String, domains : Vec<(String, Domain)>, target : Minimization, } @@ -34,7 +38,7 @@ enum Request { #[derive(Serialize, Debug)] #[serde(tag = "response")] enum Response { - Success { steps: Vec }, + Success { steps: HashMap> }, Error { error: String } } @@ -53,8 +57,8 @@ fn main_json() -> io::Result<()> { Ok(req) => { match req { Request::PerformRewrite - { domains, target } => - match get_steps(target, domains, false) { + { prob_name, domains, target } => + match get_steps(&prob_name, target, domains, false) { Some(steps) => Response::Success { steps }, None => Response::Error { error: format!("Could not rewrite target into DCP form.") diff --git a/egg-pre-dcp/src/report.rs b/egg-pre-dcp/src/report.rs new file mode 100644 index 00000000..be802345 --- /dev/null +++ b/egg-pre-dcp/src/report.rs @@ -0,0 +1,227 @@ +/*! +Data collected for evaluation. +!*/ + +use serde::Serialize; + +use crate::curvature; +use curvature::Curvature as Curvature; + +use crate::explain_util; +use crate::explain_util::Minimization as Minimization; +use explain_util::MinimizationOrExpr as MinimizationOrExpr; + +#[derive(Debug, Serialize)] +pub struct IterationReport { + // Number of rules applied up to this iteration. + num_rules_applied : usize, + // E-nodes at the start of the iteration. Note that the number of nodes at the last iteration is + // *not* the total number of nodes in the e-graph. Moreover, this number is not exact, as it is + // the size of the hashcons map, so it should be regarded as a conservative estimate. + num_nodes : usize, + // Rule applied the most times. + max_rule : String, + max_rule_count : usize, +} + +impl IterationReport { + pub fn new( + num_rules_applied: usize, + num_nodes: usize, + max_rule: String, + max_rule_count: usize) -> Self { + IterationReport { + num_rules_applied, + num_nodes, + max_rule, + max_rule_count, + } + } +} + +#[derive(Debug, Serialize)] +pub struct ComponentReport { + // Component name: objFun, expr, h1, h2, ... + component_name : String, + // E-nodes data. + nodes : usize, + node_limit : usize, + // Best term data. + best_curvature : Curvature, + best_num_vars : u32, + best_term_size : u32, + best_term: String, + // Time to build the e-graph. + build_time : u128, + // Iteration data. + num_iterations : usize, + iteration_reports : Vec, + // Number of rules applied. + num_rules_applied : usize, + // Number of steps in the explanation. + steps_count : usize, + // Time to get an explanation. + explain_time : u128, + // Component time. + component_time : u128, +} + +impl ComponentReport { + pub fn new(component_name: String) -> Self { + ComponentReport { + component_name, + nodes: 0, + node_limit: 0, + best_curvature: Curvature::Unknown, + best_num_vars: 0, + best_term_size: 0, + best_term: String::new(), + build_time: 0, + num_iterations: 0, + iteration_reports: Vec::new(), + num_rules_applied: 0, + steps_count: 0, + explain_time: 0, + component_time: 0, + } + } + + pub fn set_nodes(&mut self, nodes: usize) { + self.nodes = nodes; + } + + pub fn set_node_limit(&mut self, node_limit: usize) { + self.node_limit = node_limit; + } + + pub fn set_best_curvature(&mut self, best_curvature: Curvature) { + self.best_curvature = best_curvature; + } + + pub fn set_best_num_vars(&mut self, best_num_vars: u32) { + self.best_num_vars = best_num_vars; + } + + pub fn set_best_term_size(&mut self, best_term_size: u32) { + self.best_term_size = best_term_size; + } + + pub fn set_best_term(&mut self, best_term: String) { + self.best_term = best_term; + } + + pub fn set_build_time(&mut self, build_time: u128) { + self.build_time = build_time; + } + + pub fn set_num_iterations(&mut self, num_iterations: usize) { + self.num_iterations = num_iterations; + } + + pub fn add_iteration_report(&mut self, iteration_report: IterationReport) { + self.iteration_reports.push(iteration_report); + } + + pub fn set_num_rules_applied(&mut self, num_rules_applied: usize) { + self.num_rules_applied = num_rules_applied; + } + + pub fn set_steps_count(&mut self, steps_count: usize) { + self.steps_count = steps_count; + } + + pub fn set_explain_time(&mut self, explain_time: u128) { + self.explain_time = explain_time; + } + + pub fn set_component_time(&mut self, component_time: u128) { + self.component_time = component_time; + } +} + +#[derive(Debug, Serialize)] +pub struct Report { + prob_name : String, + // Main statistics. + total_time : u128, + total_nodes : usize, + total_steps : usize, + // Best term data. + best_curvature : Curvature, + best_num_vars : u32, + best_term_size : u32, + best_term : MinimizationOrExpr, + // Fine-grained components data. + component_reports : Vec, +} + +impl Report { + pub fn new(prob_name: String) -> Self { + Report { + prob_name, + total_time: 0, + total_nodes: 0, + total_steps: 0, + best_curvature: Curvature::Unknown, + best_num_vars: 0, + // Account for the "prob" node. + best_term_size: 1, + best_term: MinimizationOrExpr::Expr(String::new()), + component_reports: Vec::new(), + } + } + + pub fn set_total_time(&mut self, total_time: u128) { + self.total_time = total_time; + } + + pub fn set_initial_term_size(&mut self, term_size: u32) { + self.best_term_size = term_size; + } + + pub fn add_component_report(&mut self, component_report: ComponentReport) { + self.total_nodes += component_report.nodes; + self.total_steps += component_report.steps_count; + + // The curvatures of the components are assumed to be <= Convex. We take the maximum. + if self.best_curvature == Curvature::Unknown || + component_report.best_curvature > self.best_curvature { + self.best_curvature = component_report.best_curvature.clone(); + } + + // Number of vars and term size (+1 for the constr / objFun component) are just summed. + self.best_num_vars += component_report.best_num_vars; + self.best_term_size += 1 + component_report.best_term_size; + + // A little more work is needed to place the best term from the component / expression into + // the best term of the report. + if component_report.component_name.eq("expr") { + self.best_term = MinimizationOrExpr::Expr(component_report.best_term.clone()); + } else if component_report.component_name.eq("objFun") { + match &mut self.best_term { + MinimizationOrExpr::Expr(_) => { + let obj_fun = component_report.best_term.clone(); + let min = Minimization { obj_fun, constrs: Vec::new() }; + self.best_term = MinimizationOrExpr::Min(min); + }, + MinimizationOrExpr::Min(min) => { + min.obj_fun = component_report.best_term.clone(); + }, + } + } else { + match &mut self.best_term { + MinimizationOrExpr::Expr(_) => { + panic!("Minimization expected."); + }, + MinimizationOrExpr::Min(min) => { + let constr_name = component_report.component_name.clone(); + let constr_s = component_report.best_term.clone(); + min.constrs.push((constr_name, constr_s)); + }, + + } + } + + self.component_reports.push(component_report); + } +} diff --git a/egg-pre-dcp/src/rules.rs b/egg-pre-dcp/src/rules.rs index a0282e8f..c1d18355 100644 --- a/egg-pre-dcp/src/rules.rs +++ b/egg-pre-dcp/src/rules.rs @@ -24,9 +24,8 @@ pub fn rules() -> Vec> { vec![ /* Equality rules. */ - // NOTE: many more rules could apply here, but in our examples, equalities - // were either already affine or required applying logarithms to remove - // exponentials and make them affine. + // NOTE: many more rules could apply here, but in our examples, equalities were either already + // affine or required applying logarithms to remove exponentials and make them affine. rw!("log_eq_log"; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" if is_gt_zero("?a") if is_gt_zero("?b")), @@ -206,7 +205,8 @@ pub fn rules() -> Vec> { vec![ rw!("pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" if is_real_expr("?a") if is_ge_zero("?a")), - rw!("binomial_two"; "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))"), + rw!("binomial_two"; + "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))"), rw!("inv_eq_pow_neg_one"; "(inv ?a)" => "(pow ?a (neg 1))" if is_not_zero("?a")), diff --git a/egg-pre-dcp/src/test_util.rs b/egg-pre-dcp/src/test_util.rs index 1e73c060..d60cf176 100644 --- a/egg-pre-dcp/src/test_util.rs +++ b/egg-pre-dcp/src/test_util.rs @@ -1,8 +1,11 @@ use crate::domain; use domain::Domain as Domain; +use crate::explain_util; +use explain_util::Minimization as Minimization; +use explain_util::MinimizationOrExpr as MinimizationOrExpr; + use crate::extract; -use extract::Minimization as Minimization; use extract::get_steps as get_steps; use extract::get_steps_maybe_node_limit as get_steps_maybe_node_limit; use extract::get_steps_from_string_maybe_node_limit as get_steps_from_string_maybe_node_limit; @@ -19,7 +22,12 @@ fn make(obj: &str, constrs: Vec<&str>) -> Minimization { }; } -fn pre_dcp_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>, print: bool) { +fn pre_dcp_check_with_domain_maybe_print( + prob_name : &str, + domains : Vec<(&str, Domain)>, + obj: &str, + constrs: Vec<&str>, + print: bool) { let prob = make(obj, constrs); let domains = domains.iter().map(|(s, d)| ((*s).to_string(), d.clone())).collect(); @@ -27,10 +35,10 @@ fn pre_dcp_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: &st match std::env::var("EGG_PRE_DCP_NODE_LIMIT") { Result::Ok(v) => { let node_limit = v.parse::().unwrap(); - get_steps_maybe_node_limit(prob, domains, print, Some(node_limit)) + get_steps_maybe_node_limit(prob_name, prob, domains, print, Some(node_limit)) } Result::Err(_) => { - get_steps(prob, domains, print) + get_steps(prob_name, prob, domains, print) } }; if steps.is_none() { @@ -41,28 +49,32 @@ fn pre_dcp_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: &st } } -pub fn pre_dcp_check_with_domain(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { - pre_dcp_check_with_domain_maybe_print(domains, obj, constrs, false) +pub fn pre_dcp_check_with_domain( + prob_name : &str, domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(prob_name, domains, obj, constrs, false) } -pub fn pre_dcp_check_with_domain_and_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { - pre_dcp_check_with_domain_maybe_print(domains, obj, constrs, true) +pub fn pre_dcp_check_with_domain_and_print( + prob_name : &str, domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(prob_name, domains, obj, constrs, true) } -pub fn pre_dcp_check(obj: &str, constrs: Vec<&str>) { - pre_dcp_check_with_domain_maybe_print(vec![], obj, constrs, false) +pub fn pre_dcp_check(prob_name : &str, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(prob_name, vec![], obj, constrs, false) } -pub fn pre_dcp_check_and_print(obj: &str, constrs: Vec<&str>) { - pre_dcp_check_with_domain_maybe_print(vec![], obj, constrs, true) +pub fn pre_dcp_check_and_print(prob_name : &str, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(prob_name, vec![], obj, constrs, true) } // Used to test out-of-context expressions. fn pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( - domains : Vec<(&str, Domain)>, - s: &str, print: bool, - node_limit: Option) { + prob_name : &str, + domains : Vec<(&str, Domain)>, + expr_s: &str, + print: bool, + node_limit: Option) { // NOTE: Even if the node limit is passed as an argument, the environment variable is used if it // is set, hence overriding the argument. let node_limit = @@ -75,9 +87,10 @@ fn pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( node_limit } }; + let expr = MinimizationOrExpr::Expr(expr_s.to_string()); let domains = domains.iter().map(|(s, d)| ((*s).to_string(), d.clone())).collect(); - let steps = get_steps_from_string_maybe_node_limit(s, domains, true, node_limit); + let steps = get_steps_from_string_maybe_node_limit(prob_name, expr, domains, true, node_limit); if steps.is_none() { panic!("Test failed, could not rewrite target into DCP form."); } @@ -86,34 +99,44 @@ fn pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( } } -fn pre_dcp_check_expression_with_domain_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool) { - pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, None); +fn pre_dcp_check_expression_with_domain_maybe_print( + prob_name: &str, domains: Vec<(&str, Domain)>, s: &str, print: bool) { + pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( + prob_name, domains, s, print, None); } -pub fn pre_dcp_check_expression_with_domain(domains : Vec<(&str, Domain)>,s: &str) { - pre_dcp_check_expression_with_domain_maybe_print(domains, s, false); +pub fn pre_dcp_check_expression_with_domain( + prob_name: &str, domains: Vec<(&str, Domain)>,s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(prob_name, domains, s, false); } -pub fn pre_dcp_check_expression_with_domain_and_print(domains : Vec<(&str, Domain)>,s: &str) { - pre_dcp_check_expression_with_domain_maybe_print(domains, s, true); +pub fn pre_dcp_check_expression_with_domain_and_print( + prob_name: &str, domains: Vec<(&str, Domain)>,s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(prob_name, domains, s, true); } -fn pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool, node_limit: usize) { - pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, Some(node_limit)); +fn pre_dcp_check_expression_with_domain_and_node_limit_maybe_print( + prob_name: &str, domains: Vec<(&str, Domain)>, s: &str, print: bool, node_limit: usize) { + pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( + prob_name, domains, s, print, Some(node_limit)); } -pub fn pre_dcp_check_expression_with_domain_and_node_limit(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { - pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains, s, false, node_limit); +pub fn pre_dcp_check_expression_with_domain_and_node_limit( + prob_name: &str, domains: Vec<(&str, Domain)>, s: &str, node_limit: usize) { + pre_dcp_check_expression_with_domain_and_node_limit_maybe_print( + prob_name, domains, s, false, node_limit); } -pub fn pre_dcp_check_expression_with_domain_and_node_limit_and_print(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { - pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains, s, true, node_limit); +pub fn pre_dcp_check_expression_with_domain_and_node_limit_and_print( + prob_name: &str, domains: Vec<(&str, Domain)>, s: &str, node_limit: usize) { + pre_dcp_check_expression_with_domain_and_node_limit_maybe_print( + prob_name, domains, s, true, node_limit); } -pub fn pre_dcp_check_expression(s: &str) { - pre_dcp_check_expression_with_domain_maybe_print(vec![], s, false); +pub fn pre_dcp_check_expression(prob_name: &str, s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(prob_name, vec![], s, false); } -pub fn pre_dcp_check_expression_and_print(s: &str) { - pre_dcp_check_expression_with_domain_maybe_print(vec![], s, true); +pub fn pre_dcp_check_expression_and_print(prob_name: &str, s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(prob_name, vec![], s, true); } diff --git a/egg-pre-dcp/tests/test_almost_dgp.rs b/egg-pre-dcp/tests/test_almost_dgp.rs index bf71a7ac..17dc95bc 100644 --- a/egg-pre-dcp/tests/test_almost_dgp.rs +++ b/egg-pre-dcp/tests/test_almost_dgp.rs @@ -10,6 +10,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_agp1() { pre_dcp_check_and_print( + "agp1", "(exp (var u))", vec![ "(le (sub (pow (exp (var u)) 2) (div 10123 1000)) 0)" @@ -19,6 +20,7 @@ fn test_agp1() { #[test] fn test_agp2() { pre_dcp_check_and_print( + "agp2", "(exp (var u))", vec![ "(le (sub (mul (exp (var u)) (exp (var v))) (div 2691 500)) 0)" @@ -28,6 +30,7 @@ fn test_agp2() { #[test] fn test_agp3() { pre_dcp_check_and_print( + "agp3", "(add (add (exp (var u)) (exp (var v))) (exp (var w)))", vec![ "(le 2 (exp (var u)))", @@ -40,6 +43,7 @@ fn test_agp3() { #[test] fn test_agp4() { pre_dcp_check_and_print( + "agp4", "(div 1 (mul (exp (var u)) (exp (var v))))", vec![ "(le (mul (exp (var u)) (exp (var v))) (sub (sub 2 (exp (var u))) (exp (var v))))" diff --git a/egg-pre-dcp/tests/test_cost_function.rs b/egg-pre-dcp/tests/test_cost_function.rs index 849c2db3..7d9bfff2 100644 --- a/egg-pre-dcp/tests/test_cost_function.rs +++ b/egg-pre-dcp/tests/test_cost_function.rs @@ -7,6 +7,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_cost_function_number_of_variable_occurences() { pre_dcp_check( + "cost_function_number_of_variable_occurences", "0", vec![ "(le (var x) (sub 1 (var x)))" @@ -16,6 +17,7 @@ fn test_cost_function_number_of_variable_occurences() { #[test] fn test_cost_function_number_of_variable_occurences_2() { pre_dcp_check( + "cost_function_number_of_variable_occurences_2", "0", vec![ "(le (add (mul 2 (var x)) (var x)) 0)" diff --git a/egg-pre-dcp/tests/test_dgp.rs b/egg-pre-dcp/tests/test_dgp.rs index 9e99a5b6..7395dd97 100644 --- a/egg-pre-dcp/tests/test_dgp.rs +++ b/egg-pre-dcp/tests/test_dgp.rs @@ -9,6 +9,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_gp1() { pre_dcp_check_and_print( + "gp1", "(exp (var u))", vec![ "(le (pow (exp (var u)) 2) (div 10123 1000))" @@ -18,6 +19,7 @@ fn test_gp1() { #[test] fn test_gp2() { pre_dcp_check_and_print( + "gp2", "(exp (var u))", vec![ "(le (mul (exp (var u)) (exp (var v))) (div 2691 500))" @@ -27,6 +29,7 @@ fn test_gp2() { #[test] fn test_gp3() { pre_dcp_check_and_print( + "gp3", "(exp (var u))", vec![ "(le (sqrt (add (mul (exp (var u)) (exp (var u))) (exp (var v)))) 1)" @@ -36,6 +39,7 @@ fn test_gp3() { #[test] fn test_gp4() { pre_dcp_check_and_print( + "gp4", "(div (exp (var u)) (exp (var v)))", vec![ "(le 2 (exp (var u)))", @@ -48,6 +52,7 @@ fn test_gp4() { #[test] fn test_gp5() { pre_dcp_check_and_print( + "gp5", "(div 1 (div (exp (var u)) (exp (var v))))", vec![ "(le 2 (exp (var u)))", @@ -60,6 +65,7 @@ fn test_gp5() { #[test] fn test_gp6() { pre_dcp_check_and_print( + "gp6", "(div 1 (div (exp (var u)) (exp (var v))))", vec![ "(le 2 (exp (var u)))", @@ -72,6 +78,7 @@ fn test_gp6() { #[test] fn test_gp7() { pre_dcp_check_and_print( + "gp7", "(div 1 (mul (mul (exp (var h')) (exp (var w'))) (exp (var d'))))", vec![ "(le (mul 2 (add (mul (exp (var h')) (exp (var d'))) (mul (exp (var w')) (exp (var d'))))) 100)", @@ -86,6 +93,7 @@ fn test_gp7() { #[test] fn test_gp8() { pre_dcp_check_and_print( + "gp8", "(add (add (mul (mul (div 1 (exp (var u))) (div 1 (sqrt (exp (var v))))) (div 1 (exp (var w)))) (mul (mul (div 23 10) (exp (var u))) (exp (var w)))) (mul (mul (mul 4 (exp (var u))) (exp (var v))) (exp (var w))))", vec![ "(le (add (mul (mul (div 1 3) (div 1 (pow (exp (var u)) 2))) (div 1 (pow (exp (var v)) 2))) (mul (mul (div 4 3) (sqrt (exp (var v)))) (div 1 (exp (var w))))) 1)", @@ -98,6 +106,7 @@ fn test_gp8() { #[test] fn test_gp9() { pre_dcp_check_and_print( + "gp9", "(mul (mul 2 (exp (var A))) (norm2 (exp (var w)) (exp (var h))))", vec![ "(le (mul (div (mul 10 (norm2 (exp (var w)) (exp (var h)))) 2) (exp (var h))) (mul (div 1 2) (exp (var A))))", @@ -122,6 +131,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_gp7_with_params() { pre_dcp_check_with_domain( + "gp7_with_params", vec![ ("Aflr", domain::pos_dom()), ("α" , domain::pos_dom()), diff --git a/egg-pre-dcp/tests/test_dqcp.rs b/egg-pre-dcp/tests/test_dqcp.rs index a689fdde..b9e228c7 100644 --- a/egg-pre-dcp/tests/test_dqcp.rs +++ b/egg-pre-dcp/tests/test_dqcp.rs @@ -13,6 +13,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_qcp1() { pre_dcp_check_with_domain_and_print( + "qcp1", vec![("x", domain::pos_dom())], "(var x)", vec![ @@ -24,6 +25,7 @@ fn test_qcp1() { fn test_qcp2() { let d = Domain::make_oc(domain::zero(), domain::one()); pre_dcp_check_with_domain_and_print( + "qcp2", vec![("x", d)], "(sqrt (sub (div 1 (pow (var x) 2)) 1))", vec![ @@ -36,6 +38,7 @@ fn test_qcp3() { let dx = Domain::make_singleton(12.0); let dy = Domain::make_cc(domain::make_float(0.001), domain::make_float(6.0)); pre_dcp_check_with_domain_and_print( + "qcp3", vec![("x", dx), ("y", dy)], "0", vec![ @@ -47,6 +50,7 @@ fn test_qcp3() { fn test_qcp4() { let dx = Domain::make_ci(domain::make_float(10.0)); pre_dcp_check_with_domain_and_print( + "qcp4", vec![("x", dx)], "(neg (var x))", vec![ @@ -69,6 +73,7 @@ fn test_qcp2_with_params() { let da = Domain::make_ci(domain::zero()); let db = Domain::make_io(domain::one()); pre_dcp_check_with_domain( + "qcp2_with_params", vec![("x", dx), ("a", da), ("b", db)], "(sqrt (sub (div 1 (pow (var x) 2)) 1))", vec![ diff --git a/egg-pre-dcp/tests/test_main_example.rs b/egg-pre-dcp/tests/test_main_example.rs index cda8f81d..61aa920a 100644 --- a/egg-pre-dcp/tests/test_main_example.rs +++ b/egg-pre-dcp/tests/test_main_example.rs @@ -11,6 +11,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_main_example() { pre_dcp_check_with_domain_and_print( + "main_example", vec![("x", domain::pos_dom())], "(var x)", vec![ diff --git a/egg-pre-dcp/tests/test_misc.rs b/egg-pre-dcp/tests/test_misc.rs index 087cf5ac..9f94bbd1 100644 --- a/egg-pre-dcp/tests/test_misc.rs +++ b/egg-pre-dcp/tests/test_misc.rs @@ -9,6 +9,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_norm2_with_one() { pre_dcp_check_expression_with_domain( + "norm2_with_one", vec![("x", domain::free_dom())], "(sqrt (add (pow (var x) 2) 1))"); } @@ -16,6 +17,7 @@ fn test_norm2_with_one() { #[test] fn test_sqrt_pow4() { pre_dcp_check_expression_with_domain( + "sqrt_pow4", vec![("x", domain::nonneg_dom())], "(sqrt (pow (var x) 4))"); } @@ -23,12 +25,14 @@ fn test_sqrt_pow4() { #[test] fn test_div_constant_simp() { pre_dcp_check_expression( + "div_constant_simp", "(div (div (var x) 20) (div 7 20))"); } #[test] fn test_div_constant_le_simp() { pre_dcp_check_expression_with_domain( + "div_constant_le_simp", vec![("x", domain::nonneg_dom()), ("y", domain::pos_dom())], "(le (div (qol 1 (var y)) 20) (mul (div 7 20) (sqrt (var x))))"); } diff --git a/egg-pre-dcp/tests/test_quiz.rs b/egg-pre-dcp/tests/test_quiz.rs index 33c91fce..fe88730d 100644 --- a/egg-pre-dcp/tests/test_quiz.rs +++ b/egg-pre-dcp/tests/test_quiz.rs @@ -12,6 +12,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_quiz1() { pre_dcp_check_expression_with_domain_and_print( + "quiz1", vec![("x", domain::pos_dom())], "(inv (inv (var x)))"); } @@ -19,6 +20,7 @@ fn test_quiz1() { #[test] fn test_quiz2() { pre_dcp_check_expression_with_domain_and_print( + "quiz2", vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(neg (lse (log (var x)) (log (var y))))"); } @@ -26,6 +28,7 @@ fn test_quiz2() { #[test] fn test_quiz3() { pre_dcp_check_expression_with_domain_and_print( + "quiz3", vec![("x", domain::nonneg_dom())], "(pow (sqrt (var x)) 2)"); } @@ -33,6 +36,7 @@ fn test_quiz3() { #[test] fn test_quiz4() { pre_dcp_check_expression_with_domain_and_print( + "quiz4", vec![("x", domain::nonneg_dom())], "(neg (abs (sqrt (abs (var x)))))"); } @@ -40,13 +44,15 @@ fn test_quiz4() { #[test] fn test_quiz5() { pre_dcp_check_expression_with_domain_and_print( + "quiz5", vec![("x", domain::free_dom())], "(div 1 (exp (var x)))"); } #[test] fn test_quiz6() { - pre_dcp_check_expression_with_domain_and_print( + pre_dcp_check_expression_with_domain_and_print( + "quiz6", vec![("x", domain::nonneg_dom())], "(neg (log (pow (mul 364 (var x)) 2)))"); } @@ -54,6 +60,7 @@ fn test_quiz6() { #[test] fn test_quiz7() { pre_dcp_check_expression_with_domain_and_print( + "quiz7", vec![("x", domain::pos_dom())], "(pow (geo (add (var x) 2) (div 1 (var x))) 2)"); } @@ -61,6 +68,7 @@ fn test_quiz7() { #[test] fn test_quiz8() { pre_dcp_check_expression_with_domain_and_print( + "quiz8", vec![("x", domain::nonneg_dom())], "(neg (log (abs (var x))))"); } @@ -68,6 +76,7 @@ fn test_quiz8() { #[test] fn test_quiz9() { pre_dcp_check_expression_with_domain_and_print( + "quiz9", vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(div 1 (qol (inv (var x)) (inv (var y)))))"); } @@ -75,6 +84,7 @@ fn test_quiz9() { #[test] fn test_quiz10() { pre_dcp_check_expression_and_print( + "quiz10", "(pow (log (exp (var x))) 2)"); } diff --git a/egg-pre-dcp/tests/test_rules.rs b/egg-pre-dcp/tests/test_rules.rs index 92d2c4f6..97305147 100644 --- a/egg-pre-dcp/tests/test_rules.rs +++ b/egg-pre-dcp/tests/test_rules.rs @@ -12,6 +12,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_log_le_log() { pre_dcp_check_with_domain( + "log_le_log", vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "0", vec![ @@ -22,6 +23,7 @@ fn test_log_le_log() { #[test] fn test_sub_iff_add_le() { pre_dcp_check( + "sub_iff_add_le", "0", vec![ "(le (add 1 (var x)) (var x))", @@ -31,6 +33,7 @@ fn test_sub_iff_add_le() { #[test] fn test_log_le_log_rev() { pre_dcp_check( + "log_le_log_rev", "0", vec![ "(le (exp (var x)) (exp (var y)))" @@ -40,6 +43,7 @@ fn test_log_le_log_rev() { #[test] fn test_exp_add() { pre_dcp_check_with_domain( + "exp_add", vec![("x", domain::pos_dom())], "0", vec![ @@ -50,6 +54,7 @@ fn test_exp_add() { #[test] fn test_exp_neg_eq_one_div_obj() { pre_dcp_check_with_domain( + "exp_neg_eq_one_div_obj", vec![("x", Domain::make_ci(domain::one()))], "(mul (var x) (exp (neg (log (var x)))))", vec![ @@ -59,6 +64,7 @@ fn test_exp_neg_eq_one_div_obj() { #[test] fn test_exp_neg_eq_one_div_constr() { pre_dcp_check_with_domain( + "exp_neg_eq_one_div_constr", vec![("x", Domain::make_ci(domain::one()))], "(le (mul (var x) (exp (neg (log (var x))))) (var x))", vec![ @@ -68,6 +74,7 @@ fn test_exp_neg_eq_one_div_constr() { #[test] fn test_log_mul_rev_constr() { pre_dcp_check_with_domain( + "log_mul_rev_constr", vec![("x", domain::pos_dom())], "0", vec![ @@ -78,6 +85,7 @@ fn test_log_mul_rev_constr() { #[test] fn test_exp_neg_eq_one_div_rev() { pre_dcp_check( + "exp_neg_eq_one_div_rev", "(div 1 (exp (var x)))", vec![ "(le 1 (var x))" @@ -87,6 +95,7 @@ fn test_exp_neg_eq_one_div_rev() { #[test] fn test_div_self() { pre_dcp_check_with_domain( + "div_self", vec![("x", domain::pos_dom())], "0", vec![ @@ -97,6 +106,7 @@ fn test_div_self() { #[test] fn test_div_le_iff_rev() { pre_dcp_check_with_domain( + "div_le_iff_rev", vec![("x", domain::pos_dom())], "0", vec![ @@ -107,6 +117,7 @@ fn test_div_le_iff_rev() { #[test] fn test_log_div_rev_obj() { pre_dcp_check_with_domain( + "log_div_rev_obj", vec![("x", domain::pos_dom())], "(neg (sub (log (pow (var x) 2)) (log (var x))))", vec![ @@ -116,6 +127,7 @@ fn test_log_div_rev_obj() { #[test] fn test_geo_mean_fold() { pre_dcp_check_expression_with_domain( + "geo_mean_fold", vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(neg (sqrt (mul (var x) (var y))))"); } @@ -123,6 +135,7 @@ fn test_geo_mean_fold() { #[test] fn test_quad_over_lin_fold() { pre_dcp_check_expression_with_domain( + "quad_over_lin_fold", vec![("x", domain::free_dom()), ("y", domain::pos_dom())], "(div (pow (var x) 2) (var y))"); } @@ -130,6 +143,7 @@ fn test_quad_over_lin_fold() { #[test] fn test_norm2_fold() { pre_dcp_check_expression_with_domain( + "norm2_fold", vec![("x", domain::free_dom()), ("y", domain::free_dom())], "(sqrt (add (pow (var x) 2) (pow (var y) 2)))"); } diff --git a/egg-pre-dcp/tests/test_stanford.rs b/egg-pre-dcp/tests/test_stanford.rs index c1e39f12..7e1d5fc3 100644 --- a/egg-pre-dcp/tests/test_stanford.rs +++ b/egg-pre-dcp/tests/test_stanford.rs @@ -13,8 +13,9 @@ use egg_pre_dcp::test_util::{*}; fn test_stan1() { // log(e^(2x + 3) + e^(4y + 5)) = lse(2x + 3, 4y + 5) pre_dcp_check_expression_with_domain_and_print( + "stan1", vec![("x", domain::free_dom()), ("y", domain::free_dom())], - "(log (add (exp (add (mul 2 (var x)) 3)) (exp (add (mul 4 (var y)) 5))))"); + "(log (add (exp (add (mul 2 (var x)) 3)) (exp (add (mul 4 (var y)) 5))))"); } // Exercise 3.38 (e). @@ -23,24 +24,27 @@ fn test_stan2() { // (sqrt(x) + sqrt(y))^2 = x + y + 2sqrt(xy) // ... = x + y + 2geo(x, y) pre_dcp_check_expression_with_domain_and_print( + "stan2", vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], - "(neg (pow (add (sqrt (var x)) (sqrt (var y))) 2))"); + "(neg (pow (add (sqrt (var x)) (sqrt (var y))) 2))"); } // Exercise 3.67 with n = 3. #[test] fn test_stan3() { pre_dcp_check_expression_with_domain_and_print( + "stan3", vec![("x", domain::pos_dom()), ("y", domain::pos_dom()), ("z", domain::pos_dom())], - "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (sqrt (var z)))) 2))"); + "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (sqrt (var z)))) 2))"); } // Exercise 3.67 with n = 4. #[test] fn test_stan4() { pre_dcp_check_expression_with_domain_and_print( + "stan4", vec![("x", domain::pos_dom()), ("y", domain::pos_dom()), ("z", domain::pos_dom()), ("w", domain::pos_dom())], - "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (add (sqrt (var z)) (sqrt (var w))))) 2))"); + "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (add (sqrt (var z)) (sqrt (var w))))) 2))"); } // TODO: Failing because of qol curvature check simplification. @@ -85,6 +89,7 @@ fn test_stan_hard_aux(n: usize, node_limit: usize) { // Generalizaiton of 3.28. Works for n = 3,4,5,6,7 // (sqrt(x_1) + ... + sqrt(x_n))^2 // ... = sum_{i <= n} x_i + 2 * sum_{i < j <= n} geo(x_i, x_j) + let prob_name = format!("stan_hard_{}", n); let build_domain = |n| { if n < 2 { panic!("n must be >= 2"); @@ -118,7 +123,8 @@ fn test_stan_hard_aux(n: usize, node_limit: usize) { .iter() .map(|(s,d)| (s.as_str(), d.clone())) .collect::>(); - pre_dcp_check_expression_with_domain_and_node_limit(domain, &build_term(n), node_limit); + pre_dcp_check_expression_with_domain_and_node_limit( + &prob_name, domain, &build_term(n), node_limit); } // Exercise 3.67. diff --git a/lakefile.lean b/lakefile.lean index 0102f20c..d40f0c64 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -30,19 +30,22 @@ lean_lib CvxLean @[default_target] lean_lib CvxLeanTest -def compileCargo (name : String) (manifestFile : FilePath) (cargo : FilePath := "cargo") : - LogIO Unit := do +def compileCargo (name : String) (manifestFile : FilePath) (cargo : FilePath := "cargo") + (env : Array (String × Option String)) : LogIO Unit := do logInfo s!"Creating {name}" proc { + env := env cmd := cargo.toString args := #["build", "--release", "--manifest-path", manifestFile.toString] } def buildCargo (targetFile : FilePath) (manifestFile : FilePath) (targetDest : FilePath) - (oFileJobs : Array (BuildJob FilePath)) : SchedulerM (BuildJob FilePath) := + (oFileJobs : Array (BuildJob FilePath)) (stopOnSuccess : Bool) : + SchedulerM (BuildJob FilePath) := let name := targetFile.fileName.getD targetFile.toString buildFileAfterDepArray targetFile oFileJobs fun _ => do - compileCargo name manifestFile + let env := if stopOnSuccess then #[("RUSTFLAGS", some "--cfg stop_on_success")] else #[] + compileCargo name manifestFile (env := env) createParentDirs targetDest proc { cmd := "cp" @@ -55,7 +58,15 @@ target EggPreDCP (pkg) : FilePath := do let binFile := buildDir / "target" / "release" / "egg-pre-dcp" let dest := buildDir / "utils" / "egg-pre-dcp" let manifestFile := buildDir / "Cargo.toml" - buildCargo binFile manifestFile dest #[] + buildCargo binFile manifestFile dest #[] false + +@[default_target] +target EggPreDCPStopOnSuccess (pkg) : FilePath := do + let buildDir := pkg.dir / "egg-pre-dcp" + let binFile := buildDir / "target" / "release" / "egg-pre-dcp" + let dest := buildDir / "utils" / "egg-pre-dcp" + let manifestFile := buildDir / "Cargo.toml" + buildCargo binFile manifestFile dest #[] true script EggClean := do let targetDir : FilePath := "." / "egg-pre-dcp" / "target" diff --git a/scripts/evaluation/egg-pre-dcp-options/README.md b/scripts/evaluation/egg-pre-dcp-options/README.md index 816379ad..198da1ef 100644 --- a/scripts/evaluation/egg-pre-dcp-options/README.md +++ b/scripts/evaluation/egg-pre-dcp-options/README.md @@ -1,3 +1,6 @@ +WARNING: These scripts are obsolete, there is better reporting infrastructure now. +They will be removed. + Here, we compare `egg-pre-dcp` with different options (iterative node limits and stop-on-success). Run this from the root of the repository. diff --git a/scripts/evaluation/lean-pre-dcp/build_egg_pre_dcp_stop_on_success.sh b/scripts/evaluation/lean-pre-dcp/build_egg_pre_dcp_stop_on_success.sh index acf66d6f..c0704db9 100755 --- a/scripts/evaluation/lean-pre-dcp/build_egg_pre_dcp_stop_on_success.sh +++ b/scripts/evaluation/lean-pre-dcp/build_egg_pre_dcp_stop_on_success.sh @@ -1,5 +1,5 @@ #!/bin/bash -RUSTFLAGS="--cfg stop_on_success" cargo build --release --manifest-path egg-pre-dcp/Cargo.toml +RUSTFLAGS="--cfg stop_on_success" cargo build --release --manifest-path egg-pre-dcp/Cargo.toml mkdir -p egg-pre-dcp/utils cp egg-pre-dcp/target/release/egg-pre-dcp egg-pre-dcp/utils/egg-pre-dcp