Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: split egg-pre-DCP rewrite workloads by component #32

Merged
merged 21 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CvxLean/Tactic/PreDCP/Egg/EggTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def EggOCTreeExtended := EggOCTree × Array EggDomainIdentified

/-- A request consists of an `EggMinimization` and a list of domains per variable (or parameter). -/
structure EggRequest where
probName : String
domains : List (String × EggDomain)
target : EggMinimization

Expand Down
46 changes: 26 additions & 20 deletions CvxLean/Tactic/PreDCP/Egg/Runner.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ NOTE: Tuples are lists of two elements. -/
def EggRequest.toJson (e : EggRequest) : String :=
"{" ++
surroundQuotes "request" ++ " : " ++ surroundQuotes "PerformRewrite" ++ ", " ++
surroundQuotes "prob_name" ++ " : " ++ surroundQuotes e.probName ++ ", " ++
surroundQuotes "domains" ++ " : " ++
"[" ++
(", ".intercalate <| e.domains.map (fun domain =>
Expand Down Expand Up @@ -118,7 +119,8 @@ def runEggRequestRaw (requestJson : String) : MetaM String := do
return stdout

/-- Read `egg`'s output and trun it into an array of `EggRewrite`s. -/
def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do
def parseEggResponse (responseString : String) :
MetaM (HashMap String (Array EggRewrite)) := do
dbg_trace s!"Egg response: {responseString}"
let outJson : Json ← match Json.parse responseString with
| Except.error e => throwError (s!"error calling `egg`, JSON parsing error ({e}).")
Expand All @@ -133,29 +135,33 @@ def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do

else
let steps ← liftExcept <| outJson.getObjVal? "steps"
let steps ← liftExcept <| Json.getArr? steps

let res := steps.map fun step =>
let rewriteName := (step.getObjValD "rewrite_name").getStr!
let direction := match (step.getObjValD "direction").getStr! with
| "Forward" => EggRewriteDirection.Forward
| "Backward" => EggRewriteDirection.Backward
| _ => panic! "Unexpected rewrite direction."
let location := (step.getObjValD "location").getStr!
let subexprFrom := (step.getObjValD "subexpr_from").getStr!
let subexprTo := (step.getObjValD "subexpr_to").getStr!
let expectedTerm := (step.getObjValD "expected_term").getStr!
{ rewriteName := rewriteName,
direction := direction,
location := location,
subexprFrom := subexprFrom,
subexprTo := subexprTo,
expectedTerm := expectedTerm }
let steps ← liftExcept <| Json.getObj? steps

let mut res := HashMap.empty
for ⟨componentName, componentSteps⟩ in steps.toArray do
let componentSteps ← liftExcept <| Json.getArr? componentSteps
let componentStepsParsed : Array EggRewrite := componentSteps.map fun step =>
let rewriteName := (step.getObjValD "rewrite_name").getStr!
let direction := match (step.getObjValD "direction").getStr! with
| "Forward" => EggRewriteDirection.Forward
| "Backward" => EggRewriteDirection.Backward
| _ => panic! "Unexpected rewrite direction."
let location := (step.getObjValD "location").getStr!
let subexprFrom := (step.getObjValD "subexpr_from").getStr!
let subexprTo := (step.getObjValD "subexpr_to").getStr!
let expectedTerm := (step.getObjValD "expected_term").getStr!
{ rewriteName := rewriteName,
direction := direction,
location := location,
subexprFrom := subexprFrom,
subexprTo := subexprTo,
expectedTerm := expectedTerm }
res := res.insert componentName componentStepsParsed

return res

/-- Run request to `egg` and parse the output to get an array of rewrites, if successful. -/
def runEggRequest (request : EggRequest) : MetaM (Array EggRewrite) :=
def runEggRequest (request : EggRequest) : MetaM (HashMap String (Array EggRewrite)) :=
dbg_trace s!"Running egg request: {request.toJson}"
runEggRequestRaw request.toJson >>= parseEggResponse

Expand Down
29 changes: 18 additions & 11 deletions CvxLean/Tactic/PreDCP/PreDCP.lean
Original file line number Diff line number Diff line change
Expand Up @@ -207,30 +207,37 @@ def preDCPBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do
-- Prepare `egg` request.
let eggMinimization := EggMinimization.ofEggOCTree gStr
let eggRequest : EggRequest :=
{ domains := (varDomainConstrs ++ paramDomains).data,
{ probName := "lean_prob",
domains := (varDomainConstrs ++ paramDomains).data,
target := eggMinimization }

try
-- Call `egg` (time it for evaluation).
let bef ← BaseIO.toIO IO.monoMsNow
let steps ← runEggRequest eggRequest
let stepsByComponent ← runEggRequest eggRequest
let aft ← BaseIO.toIO IO.monoMsNow
let diff := aft - bef
dbg_trace s!"Egg time: {diff} ms."
dbg_trace s!"Number of steps: {steps.size}."
dbg_trace s!"Term size: {probSize}."
dbg_trace s!"Term JSON: {eggMinimization.toJson}."

-- Apply steps.
let mut stepsCount := 0
let mut g := g
for step in steps do
let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic
if gs.length != 1 then
trace[CvxLean.debug] "Remaining goals: {gs}."
throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)."
else
trace[CvxLean.debug] "Rewrote {step.rewriteName}."
g := gs[0]!
for (_componentName, steps) in stepsByComponent do
-- TODO: Since rewrites are now split by component, we could apply the corresponding
-- congruence rule just once.
for step in steps do
stepsCount := stepsCount + 1
let gs ← Tactic.run g <| (evalStep step varsNames paramsNames paramsDecls tagsMap).toTactic
if gs.length != 1 then
trace[CvxLean.debug] "Remaining goals: {gs}."
throwPreDCPError "failed to rewrite {step.rewriteName} ({gs.length} goals remaining)."
else
trace[CvxLean.debug] "Rewrote {step.rewriteName}."
g := gs[0]!

dbg_trace s!"Number of steps: {stepsCount}."

let gsFinal ← evalTacticAt (← `(tactic| equivalence_rfl)) g
if gsFinal.length != 0 then
Expand Down
7 changes: 7 additions & 0 deletions egg-pre-dcp/src/curvature.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::cmp::Ordering;
use std::fmt;
use serde::Serialize;

use crate::domain;
use domain::Domain as Domain;
Expand Down Expand Up @@ -78,6 +79,12 @@ impl fmt::Display for Curvature {
}
}

impl Serialize for Curvature {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
serializer.serialize_str(&self.to_string())
}
}

pub fn join(c1: Curvature, c2: Curvature) -> Curvature {
if c1 <= c2 { c2 } else if c2 <= c1 { c1 } else { Unknown }
}
Expand Down
82 changes: 81 additions & 1 deletion egg-pre-dcp/src/explain_util.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,93 @@
use egg::{*};
use std::collections::HashMap;
use serde::Serialize;
use serde::{Deserialize, Serialize};

use crate::optimization;
use optimization::Optimization as Optimization;
use optimization::Meta as Meta;

pub type Rewrite = egg::Rewrite<Optimization, Meta>;

// Representation of an optimization problem by its components.

#[derive(Deserialize, Debug)]
pub struct Minimization {
pub obj_fun : String,
pub constrs : Vec<(String, String)>,
}

impl ToString for Minimization {
fn to_string(&self) -> String {
let obj_fun_s: String = format!("(objFun {})", self.obj_fun);
let constrs_s_l : Vec<String> =
self.constrs.iter().map(
|(h, c)| format!("(constr {} {})", h, c)).collect();
let constr_s = format!("(constrs {})", constrs_s_l.join(" "));
return format!("(prob {} {})", obj_fun_s, constr_s);
}
}

impl Serialize for Minimization {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
serializer.serialize_str(&self.to_string())
}
}

pub struct MinimizationIter {
min_iter : Vec<(String, String)>,
}

impl Minimization {
pub fn iter(&self) -> MinimizationIter {
let obj_fun_s = format!("(objFun {})", self.obj_fun);
let mut min_iter = vec![("objFun".to_string(), obj_fun_s.clone())];
min_iter.append(&mut self.constrs.clone());
MinimizationIter { min_iter }
}
}

impl Iterator for MinimizationIter {
type Item = (String, String);

fn next(&mut self) -> Option<Self::Item> {
if self.min_iter.is_empty() {
None
} else {
Some(self.min_iter.remove(0))
}
}
}

// Since sometimes we want to rewrite a full optimization problem, and sometimes just an expression.
#[derive(Deserialize, Debug)]
pub enum MinimizationOrExpr {
Min(Minimization),
Expr(String),
}

impl Serialize for MinimizationOrExpr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer {
match self {
MinimizationOrExpr::Min(min) => min.serialize(serializer),
MinimizationOrExpr::Expr(expr) => expr.serialize(serializer),
}
}
}

impl MinimizationOrExpr {
pub fn iter(&self) -> MinimizationIter {
match self {
MinimizationOrExpr::Min(min) => min.iter(),
MinimizationOrExpr::Expr(expr) => {
let min_iter = vec![("expr".to_string(), expr.clone())];
MinimizationIter { min_iter }
}
}
}
}

// Reading flat terms.

#[derive(Clone, Copy, Serialize, Debug)]
pub enum Direction {
Forward,
Expand Down
Loading
Loading