Skip to content

Commit

Permalink
feat: split egg-pre-DCP rewrite workloads by component (#32)
Browse files Browse the repository at this point in the history
* style: line length in `rules.rs`

* wip: report structure?

* wip: iterate through components

* feat: keep track of problem name

* feat: `Report` structure

* feat: serialize `Curvature`

* feat: full reports

* chore: note on parallelization

* fix: imports in `main.rs` and return a hashmap

* fix: add names to tests

* fix: do not fail on first component added to the report

* fix: control flow issue, split by component working

* feat: count full problem structure in term size

* feat: start adjusting Lean side to new steps format

* feat: add name to `EggRequest`

* fix: tag `response` in `Response`

* fix: add location directly in explanation

* chore: note about congruence rule

* chore: warning about evaluation scripts

* feat: enable `lake build EggPreDCPStopOnSuccess`

* fix: correctly pass Rust flags to environment
  • Loading branch information
ramonfmir committed Apr 25, 2024
1 parent 46087d0 commit 25b66f2
Show file tree
Hide file tree
Showing 23 changed files with 709 additions and 252 deletions.
1 change: 1 addition & 0 deletions CvxLean/Tactic/PreDCP/Egg/EggTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def EggOCTreeExtended := EggOCTree × Array EggDomainIdentified

/-- A request consists of an `EggMinimization` and a list of domains per variable (or parameter). -/
structure EggRequest where
probName : String
domains : List (String × EggDomain)
target : EggMinimization

Expand Down
46 changes: 26 additions & 20 deletions CvxLean/Tactic/PreDCP/Egg/Runner.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ NOTE: Tuples are lists of two elements. -/
def EggRequest.toJson (e : EggRequest) : String :=
"{" ++
surroundQuotes "request" ++ " : " ++ surroundQuotes "PerformRewrite" ++ ", " ++
surroundQuotes "prob_name" ++ " : " ++ surroundQuotes e.probName ++ ", " ++
surroundQuotes "domains" ++ " : " ++
"[" ++
(", ".intercalate <| e.domains.map (fun domain =>
Expand Down Expand Up @@ -118,7 +119,8 @@ def runEggRequestRaw (requestJson : String) : MetaM String := do
return stdout

/-- Read `egg`'s output and trun it into an array of `EggRewrite`s. -/
def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do
def parseEggResponse (responseString : String) :
MetaM (HashMap String (Array EggRewrite)) := do
dbg_trace s!"Egg response: {responseString}"
let outJson : Json ← match Json.parse responseString with
| Except.error e => throwError (s!"error calling `egg`, JSON parsing error ({e}).")
Expand All @@ -133,29 +135,33 @@ def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do

else
let steps ← liftExcept <| outJson.getObjVal? "steps"
let steps ← liftExcept <| Json.getArr? steps

let res := steps.map fun step =>
let rewriteName := (step.getObjValD "rewrite_name").getStr!
let direction := match (step.getObjValD "direction").getStr! with
| "Forward" => EggRewriteDirection.Forward
| "Backward" => EggRewriteDirection.Backward
| _ => panic! "Unexpected rewrite direction."
let location := (step.getObjValD "location").getStr!
let subexprFrom := (step.getObjValD "subexpr_from").getStr!
let subexprTo := (step.getObjValD "subexpr_to").getStr!
let expectedTerm := (step.getObjValD "expected_term").getStr!
{ rewriteName := rewriteName,
direction := direction,
location := location,
subexprFrom := subexprFrom,
subexprTo := subexprTo,
expectedTerm := expectedTerm }
let steps ← liftExcept <| Json.getObj? steps

let mut res := HashMap.empty
for ⟨componentName, componentSteps⟩ in steps.toArray do
let componentSteps ← liftExcept <| Json.getArr? componentSteps
let componentStepsParsed : Array EggRewrite := componentSteps.map fun step =>
let rewriteName := (step.getObjValD "rewrite_name").getStr!
let direction := match (step.getObjValD "direction").getStr! with
| "Forward" => EggRewriteDirection.Forward
| "Backward" => EggRewriteDirection.Backward
| _ => panic! "Unexpected rewrite direction."
let location := (step.getObjValD "location").getStr!
let subexprFrom := (step.getObjValD "subexpr_from").getStr!
let subexprTo := (step.getObjValD "subexpr_to").getStr!
let expectedTerm := (step.getObjValD "expected_term").getStr!
{ rewriteName := rewriteName,
direction := direction,
location := location,
subexprFrom := subexprFrom,
subexprTo := subexprTo,
expectedTerm := expectedTerm }
res := res.insert componentName componentStepsParsed

return res

/-- Run request to `egg` and parse the output to get an array of rewrites, if successful. -/
def runEggRequest (request : EggRequest) : MetaM (Array EggRewrite) :=
def runEggRequest (request : EggRequest) : MetaM (HashMap String (Array EggRewrite)) :=
dbg_trace s!"Running egg request: {request.toJson}"
runEggRequestRaw request.toJson >>= parseEggResponse

Expand Down
29 changes: 18 additions & 11 deletions CvxLean/Tactic/PreDCP/PreDCP.lean
Original file line number Diff line number Diff line change
Expand Up @@ -207,30 +207,37 @@ def preDCPBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do
-- Prepare `egg` request.
let eggMinimization := EggMinimization.ofEggOCTree gStr
let eggRequest : EggRequest :=
{ domains := (varDomainConstrs ++ paramDomains).data,
{ probName := "lean_prob",
domains := (varDomainConstrs ++ paramDomains).data,
target := eggMinimization }

try
-- Call `egg` (time it for evaluation).
let bef ← BaseIO.toIO IO.monoMsNow
let steps ← runEggRequest eggRequest
let stepsByComponent ← runEggRequest eggRequest
let aft ← BaseIO.toIO IO.monoMsNow
let diff := aft - bef
dbg_trace s!"Egg time: {diff} ms."
dbg_trace s!"Number of steps: {steps.size}."
dbg_trace s!"Term size: {probSize}."
dbg_trace s!"Term JSON: {eggMinimization.toJson}."

-- Apply steps.
let mut stepsCount := 0
let mut g := g
for step in steps do
let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic
if gs.length != 1 then
trace[CvxLean.debug] "Remaining goals: {gs}."
throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)."
else
trace[CvxLean.debug] "Rewrote {step.rewriteName}."
g := gs[0]!
for (_componentName, steps) in stepsByComponent do
-- TODO: Since rewrites are now split by component, we could apply the corresponding
-- congruence rule just once.
for step in steps do
stepsCount := stepsCount + 1
let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic
if gs.length != 1 then
trace[CvxLean.debug] "Remaining goals: {gs}."
throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)."
else
trace[CvxLean.debug] "Rewrote {step.rewriteName}."
g := gs[0]!

dbg_trace s!"Number of steps: {stepsCount}."

let gsFinal ← evalTacticAt (← `(tactic| equivalence_rfl)) g
if gsFinal.length != 0 then
Expand Down
7 changes: 7 additions & 0 deletions egg-pre-dcp/src/curvature.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::cmp::Ordering;
use std::fmt;
use serde::Serialize;

use crate::domain;
use domain::Domain as Domain;
Expand Down Expand Up @@ -78,6 +79,12 @@ impl fmt::Display for Curvature {
}
}

impl Serialize for Curvature {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
serializer.serialize_str(&self.to_string())
}
}

pub fn join(c1: Curvature, c2: Curvature) -> Curvature {
if c1 <= c2 { c2 } else if c2 <= c1 { c1 } else { Unknown }
}
Expand Down
82 changes: 81 additions & 1 deletion egg-pre-dcp/src/explain_util.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,93 @@
use egg::{*};
use std::collections::HashMap;
use serde::Serialize;
use serde::{Deserialize, Serialize};

use crate::optimization;
use optimization::Optimization as Optimization;
use optimization::Meta as Meta;

pub type Rewrite = egg::Rewrite<Optimization, Meta>;

// Representation of an optimization problem by its components.

#[derive(Deserialize, Debug)]
pub struct Minimization {
pub obj_fun : String,
pub constrs : Vec<(String, String)>,
}

impl ToString for Minimization {
fn to_string(&self) -> String {
let obj_fun_s: String = format!("(objFun {})", self.obj_fun);
let constrs_s_l : Vec<String> =
self.constrs.iter().map(
|(h, c)| format!("(constr {} {})", h, c)).collect();
let constr_s = format!("(constrs {})", constrs_s_l.join(" "));
return format!("(prob {} {})", obj_fun_s, constr_s);
}
}

impl Serialize for Minimization {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
serializer.serialize_str(&self.to_string())
}
}

pub struct MinimizationIter {
min_iter : Vec<(String, String)>,
}

impl Minimization {
pub fn iter(&self) -> MinimizationIter {
let obj_fun_s = format!("(objFun {})", self.obj_fun);
let mut min_iter = vec![("objFun".to_string(), obj_fun_s.clone())];
min_iter.append(&mut self.constrs.clone());
MinimizationIter { min_iter }
}
}

impl Iterator for MinimizationIter {
type Item = (String, String);

fn next(&mut self) -> Option<Self::Item> {
if self.min_iter.is_empty() {
None
} else {
Some(self.min_iter.remove(0))
}
}
}

// Since sometimes we want to rewrite a full optimization problem, and sometimes just an expression.
#[derive(Deserialize, Debug)]
pub enum MinimizationOrExpr {
Min(Minimization),
Expr(String),
}

impl Serialize for MinimizationOrExpr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
match self {
MinimizationOrExpr::Min(min) => min.serialize(serializer),
MinimizationOrExpr::Expr(expr) => expr.serialize(serializer),
}
}
}

impl MinimizationOrExpr {
pub fn iter(&self) -> MinimizationIter {
match self {
MinimizationOrExpr::Min(min) => min.iter(),
MinimizationOrExpr::Expr(expr) => {
let min_iter = vec![("expr".to_string(), expr.clone())];
MinimizationIter { min_iter }
}
}
}
}

// Reading flat terms.

#[derive(Clone, Copy, Serialize, Debug)]
pub enum Direction {
Forward,
Expand Down
Loading

0 comments on commit 25b66f2

Please sign in to comment.