diff --git a/.gitignore b/.gitignore index 2aa5a114..4c143eca 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ *.code-workspace /playground -/evaluation +/scripts/evaluation/data /plots /CvxLean/Tactic/Solver/Mosek/Path.lean diff --git a/CvxLean/Command/Equivalence.lean b/CvxLean/Command/Equivalence.lean index 1e5038c8..5ac28d75 100644 --- a/CvxLean/Command/Equivalence.lean +++ b/CvxLean/Command/Equivalence.lean @@ -40,9 +40,11 @@ type `Equivalence p q`. -/ def elabEquivalenceProof (lhs : Expr) (rhsName : Name) (stx : Syntax) : TermElabM (Expr × Expr) := elabTransformationProof TransformationGoal.Equivalence lhs rhsName stx +/-- Open an equivalence environment. -/ syntax (name := equivalence) "equivalence" ident "/" ident declSig ":=" Lean.Parser.Term.byTactic : command +/-- Open an equivalence environment and try to generate a computable backward map. -/ syntax (name := equivalenceAndBwdMap) "equivalence'" ident "/" ident declSig ":=" Lean.Parser.Term.byTactic : command diff --git a/CvxLean/Command/Reduction.lean b/CvxLean/Command/Reduction.lean index 1f2461c9..96dba4b0 100644 --- a/CvxLean/Command/Reduction.lean +++ b/CvxLean/Command/Reduction.lean @@ -27,7 +27,7 @@ environment: Writing `reduction'` instead of `reduction` will also generate a backward solution map at the level of floats. -It is essentially the same as `Command/Equivalence.lean`, except that the goal is to prove a +It is essentially the same as `CvxLean/Command/Equivalence.lean`, except that the goal is to prove a reduction instead of an equivalence. We note that proving that two problems are equivalent is usually preferred. -/ @@ -45,9 +45,11 @@ type `Reduction p q`. -/ def elabReductionProof (lhs : Expr) (rhsName : Name) (stx : Syntax) : TermElabM (Expr × Expr) := elabTransformationProof TransformationGoal.Reduction lhs rhsName stx +/-- Open a reduction environment. -/ syntax (name := reduction) "reduction" ident "/" ident declSig ":=" Lean.Parser.Term.byTactic : command +/-- Open a reduction environment and try to generate a computable backward map. -/ syntax (name := reductionAndBwdMap) "reduction'" ident "/" ident declSig ":=" Lean.Parser.Term.byTactic : command diff --git a/CvxLean/Command/Relaxation.lean b/CvxLean/Command/Relaxation.lean index 46239702..47d45e82 100644 --- a/CvxLean/Command/Relaxation.lean +++ b/CvxLean/Command/Relaxation.lean @@ -20,10 +20,10 @@ environment: * `q := r`, * `rel : p ≽' q`. -This command is very similar to `equivalence` (`Command/Equivalence.lean`) and `reduction` -(`Command/Reduction.lean`) in how it is designed. Of course, the key difference is that the goal -is to prove a relaxation. Note that in this case, there is no option to create a backward map as -relaxations do not guarantee that the solution can be mapped back. +This command is very similar to `equivalence` (`CvxLean/Command/Equivalence.lean`) and `reduction` +(`CvxLean/Command/Reduction.lean`) in how it is designed. Of course, the key difference is that the +goal is to prove a relaxation. Note that in this case, there is no option to create a backward map +as relaxations do not guarantee that the solution can be mapped back. -/ namespace CvxLean @@ -39,10 +39,11 @@ type `Relaxation p q`. -/ def elabRelaxationProof (lhs : Expr) (rhsName : Name) (stx : Syntax) : TermElabM (Expr × Expr) := elabTransformationProof TransformationGoal.Relaxation lhs rhsName stx +/-- Open a relaxation environment. -/ syntax (name := relaxation) "relaxation" ident "/" ident declSig ":=" Lean.Parser.Term.byTactic : command -/-- Relaxation command. -/ +/-- Definition of the `relaxation` command. -/ @[command_elab «relaxation»] def evalRelaxation : CommandElab := fun stx => match stx with | `(relaxation $relIdStx / $probIdStx $declSig := $proofStx) => do diff --git a/CvxLean/Command/Solve.lean b/CvxLean/Command/Solve.lean index 23fce36c..6481dfaf 100644 --- a/CvxLean/Command/Solve.lean +++ b/CvxLean/Command/Solve.lean @@ -27,7 +27,8 @@ def getProblemName (stx : Syntax) : MetaM Name := do if ¬ idStx.getId.isStr then throwError "Invalid name for minimization problem: {idStx}." - return idStx.getId + let currNamespace ← getCurrNamespace + return currNamespace ++ idStx.getId /-- Call DCP and get the problem in conic form as well as `ψ`, the backward map from the equivalence. -/ diff --git a/CvxLean/Command/Solve/Float/Coeffs.lean b/CvxLean/Command/Solve/Float/Coeffs.lean index 904c7c0c..aba0c2e0 100644 --- a/CvxLean/Command/Solve/Float/Coeffs.lean +++ b/CvxLean/Command/Solve/Float/Coeffs.lean @@ -10,10 +10,10 @@ import CvxLean.Command.Solve.Float.RealToFloatLibrary # Extract coefficients from problem to generate problem data This file defines `determineCoeffsFromExpr`, which takes a `MinimizationExpr` and returns -`ProblemData`. This procedure is used by `Commands/Solve.lean` to be able to call an external -solver. +`ProblemData`. This procedure is used by `CvxLean/Commands/Solve.lean` to be able to call an +external solver. -## TODO +### TODO * This is probably a big source of inefficency for the `solve` command. We should come up with a better way to extract the numerical values from the Lean expressions. diff --git a/CvxLean/Command/Solve/Float/FloatToReal.lean b/CvxLean/Command/Solve/Float/FloatToReal.lean index 4799d5f2..90ba5631 100644 --- a/CvxLean/Command/Solve/Float/FloatToReal.lean +++ b/CvxLean/Command/Solve/Float/FloatToReal.lean @@ -2,8 +2,8 @@ import Mathlib.Data.Real.Basic import CvxLean.Lib.Math.Data.Real /-! -Conversion used in `Solve/Conic.lean` to read the solver's output into an expression to which we can -apply the backward map. +Conversion used in `CvxLean/Command/Solve/Conic.lean` to read the solver's output into an expression +to which we can apply the backward map. -/ namespace CvxLean diff --git a/CvxLean/Command/Solve/Float/ProblemData.lean b/CvxLean/Command/Solve/Float/ProblemData.lean index d7c675b9..d625b316 100644 --- a/CvxLean/Command/Solve/Float/ProblemData.lean +++ b/CvxLean/Command/Solve/Float/ProblemData.lean @@ -6,7 +6,7 @@ We provide an interface to build this data by adding any of the possible types o data can then be transformed in a straightforward manner into Conic Benchmark Format, for example. The procedure that creates `ProblemData` from an optimization problem can be found in -`Command/Solve/Coeffs.lean`. +`CvxLean/Command/Solve/Coeffs.lean`. -/ namespace CvxLean @@ -115,7 +115,7 @@ def addZeroConstraint (data : ProblemData) (a : Array Float) (b : Float) : Probl /-- Add exponential cone constraint `∑ i, aᵢxᵢ + b ∈ 𝒦ₑ` to problem data. Note that the second-order cone is `3`-dimensional, so to capture `(x, y, z) ∈ 𝒦ₑ` we do `x ∈ 𝒦ₑ`, `y ∈ 𝒦ₑ`, and `z ∈ 𝒦ₑ` consecutively. We keep track of how to group consecutive constraints in -`Command/Solve/Float/Coeffs.lean`. -/ +`CvxLean/Command/Solve/Float/Coeffs.lean`. -/ def addExpConstraint (data : ProblemData) (a : Array Float) (b : Float) : ProblemData := data.addScalarAffineConstraintOnlyVector a b ScalarConeType.Exp diff --git a/CvxLean/Command/Solve/Float/RealToFloatCmd.lean b/CvxLean/Command/Solve/Float/RealToFloatCmd.lean index b81c9009..c1d6cb6b 100644 --- a/CvxLean/Command/Solve/Float/RealToFloatCmd.lean +++ b/CvxLean/Command/Solve/Float/RealToFloatCmd.lean @@ -14,7 +14,7 @@ import CvxLean.Command.Solve.Float.RealToFloatExt # Conversion of Real to Float (command) We define the `realToFloat` function to go from real `Expr`s to float `Expr`s as well as the -`addRealToFloat` command to add new translations and the `#realToFloat` command to test the +`add_real_to_float` command to add new translations and the `#real_to_float` command to test the translation. -/ @@ -23,7 +23,7 @@ namespace CvxLean open Lean Expr Elab Meta Command Term syntax (name := addRealToFloatCommand) - "addRealToFloat" Lean.Parser.Term.funBinder* ":" term ":=" term : command + "add_real_to_float" Lean.Parser.Term.funBinder* ":" term ":=" term : command /-- Traverse expression recursively and if a match is found in the real-to-float library return the float version. The "correctness" of this function depends on the translations defined in the @@ -72,11 +72,11 @@ partial def realToFloat (e : Expr) : MetaM Expr := do return e | _ => return e -/-- The `addRealToFloat` command, which simply adds the user-defined expressions to the environment -extension. -/ +/-- The `add_real_to_float` command, which simply adds the user-defined expressions to the +environment extension. -/ @[command_elab addRealToFloatCommand] def elabAddRealToFloatCommand : CommandElab -| `(addRealToFloat : $real := $float) => +| `(add_real_to_float : $real := $float) => liftTermElabM do let real ← elabTermAndSynthesize real.raw none let float ← elabTermAndSynthesize float.raw none @@ -84,19 +84,19 @@ def elabAddRealToFloatCommand : CommandElab | _ => throwUnsupportedSyntax macro_rules -| `(addRealToFloat $idents:funBinder* : $real := $float) => do +| `(add_real_to_float $idents:funBinder* : $real := $float) => do if idents.size != 0 then - `(addRealToFloat : fun $idents:funBinder* => $real := fun $idents:funBinder* => $float) + `(add_real_to_float : fun $idents:funBinder* => $real := fun $idents:funBinder* => $float) else Macro.throwUnsupported syntax (name:=realToFloatCommand) - "#realToFloat" term : command + "#real_to_float" term : command /-- Transform the given expression to its float version and log the result. -/ @[command_elab realToFloatCommand] unsafe def elabRealToFloatCommand : CommandElab -| `(#realToFloat $stx) => +| `(#real_to_float $stx) => liftTermElabM do let e ← elabTermAndSynthesize stx.raw none let res ← realToFloat e diff --git a/CvxLean/Command/Solve/Float/RealToFloatLibrary.lean b/CvxLean/Command/Solve/Float/RealToFloatLibrary.lean index e9620c58..822b28dc 100644 --- a/CvxLean/Command/Solve/Float/RealToFloatLibrary.lean +++ b/CvxLean/Command/Solve/Float/RealToFloatLibrary.lean @@ -14,200 +14,200 @@ namespace CvxLean section Basic -addRealToFloat : Real := +add_real_to_float : Real := Float -addRealToFloat : Real.instInhabitedReal := +add_real_to_float : Real.instInhabitedReal := instInhabitedFloat -addRealToFloat : Real.instZeroReal := +add_real_to_float : Real.instZeroReal := Zero.mk (0 : Float) -addRealToFloat : Real.instOneReal := +add_real_to_float : Real.instOneReal := One.mk (1 : Float) -addRealToFloat : Real.instLEReal := +add_real_to_float : Real.instLEReal := instLEFloat -addRealToFloat : Real.instLTReal := +add_real_to_float : Real.instLTReal := instLTFloat -addRealToFloat : Real.instDivReal := +add_real_to_float : Real.instDivReal := instDivFloat -addRealToFloat : Real.instPowReal := +add_real_to_float : Real.instPowReal := Pow.mk Float.pow -addRealToFloat (n : Nat) (i) : @OfNat.ofNat Real n i := +add_real_to_float (n : Nat) (i) : @OfNat.ofNat Real n i := Float.ofNat n -addRealToFloat (n : Nat) : AddMonoidWithOne.toNatCast.natCast (R := ℝ) n := +add_real_to_float (n : Nat) : AddMonoidWithOne.toNatCast.natCast (R := ℝ) n := Float.ofNat n -addRealToFloat (i) (x : ℕ) : @Nat.cast Real i x := +add_real_to_float (i) (x : ℕ) : @Nat.cast Real i x := Float.ofNat x -addRealToFloat (k : Nat) : +add_real_to_float (k : Nat) : @SMul.smul ℕ ℝ AddMonoid.toNatSMul k := (fun (x : Float) => (OfNat.ofNat k) * x) -addRealToFloat (i) : @Ring.toNeg Real i := +add_real_to_float (i) : @Ring.toNeg Real i := Neg.mk Float.neg -addRealToFloat : Real.instNegReal := instNegFloat +add_real_to_float : Real.instNegReal := instNegFloat -addRealToFloat : Real.instAddReal := instAddFloat +add_real_to_float : Real.instAddReal := instAddFloat -addRealToFloat (i) : @HAdd.hAdd Real Real Real i := +add_real_to_float (i) : @HAdd.hAdd Real Real Real i := Float.add -addRealToFloat (i) : @instHAdd Real i := +add_real_to_float (i) : @instHAdd Real i := @HAdd.mk Float Float Float Float.add -addRealToFloat : Real.instSubReal := instSubFloat +add_real_to_float : Real.instSubReal := instSubFloat -addRealToFloat (i) : @HSub.hSub Real Real Real i := +add_real_to_float (i) : @HSub.hSub Real Real Real i := Float.sub -addRealToFloat (i) : @instHSub Real i := +add_real_to_float (i) : @instHSub Real i := @HSub.mk Float Float Float Float.sub -addRealToFloat : Real.instMulReal := instMulFloat +add_real_to_float : Real.instMulReal := instMulFloat -addRealToFloat (i) : @HMul.hMul Real Real Real i := +add_real_to_float (i) : @HMul.hMul Real Real Real i := Float.mul -addRealToFloat (i) : @instHMul Real i := +add_real_to_float (i) : @instHMul Real i := @HMul.mk Float Float Float Float.mul -addRealToFloat (i) (k) : @HSMul.hSMul ℕ ℝ ℝ i k := +add_real_to_float (i) (k) : @HSMul.hSMul ℕ ℝ ℝ i k := fun r => Float.ofNat k * r -addRealToFloat (i) : @HDiv.hDiv Real Real Real i := +add_real_to_float (i) : @HDiv.hDiv Real Real Real i := Float.div -addRealToFloat (i) : @instHDiv Real i := +add_real_to_float (i) : @instHDiv Real i := @HDiv.mk Float Float Float Float.div -addRealToFloat (i) : @HPow.hPow Real Nat Real i := +add_real_to_float (i) : @HPow.hPow Real Nat Real i := fun f n => Float.pow f (Float.ofNat n) -addRealToFloat (i) : @HPow.hPow.{0, 0, 0} Real Real Real i := +add_real_to_float (i) : @HPow.hPow.{0, 0, 0} Real Real Real i := fun f n => Float.pow f n -addRealToFloat (β) (i) : @instHPow Real β i := +add_real_to_float (β) (i) : @instHPow Real β i := @HPow.mk Float Float Float Float.pow -addRealToFloat (n) (i) : @HPow.hPow (Fin n → Real) Real (Fin n → Real) i := +add_real_to_float (n) (i) : @HPow.hPow (Fin n → Real) Real (Fin n → Real) i := fun (x : Fin n → Float) (p : Float) (i : Fin n) => Float.pow (x i) p -addRealToFloat (n) (β) (i) : @instHPow (Fin n → Real) β i := +add_real_to_float (n) (β) (i) : @instHPow (Fin n → Real) β i := @HPow.mk (Fin n → Float) Float (Fin n → Float) (fun x p i => Float.pow (x i) p) -addRealToFloat (i) : @LE.le Real i := +add_real_to_float (i) : @LE.le Real i := Float.le -addRealToFloat : Real.pi := +add_real_to_float : Real.pi := 2 * Float.acos 0 -addRealToFloat : @Real.exp := +add_real_to_float : @Real.exp := Float.exp -addRealToFloat : @Vec.exp.{0} := +add_real_to_float : @Vec.exp.{0} := @Vec.Computable.exp.{0} -addRealToFloat : @Real.sqrt := +add_real_to_float : @Real.sqrt := Float.sqrt -addRealToFloat : @Real.log := +add_real_to_float : @Real.log := Float.log -addRealToFloat (n) (i) : @Norm.norm.{0} (Fin n → ℝ) i := +add_real_to_float (n) (i) : @Norm.norm.{0} (Fin n → ℝ) i := @Real.Computable.norm n -addRealToFloat (i) : @OfScientific.ofScientific Real i := +add_real_to_float (i) : @OfScientific.ofScientific Real i := Float.ofScientific -addRealToFloat : Real.natCast := +add_real_to_float : Real.natCast := NatCast.mk Float.ofNat end Basic section Matrix -addRealToFloat (n m k) : +add_real_to_float (n m k) : @HSMul.hSMul ℕ (Matrix (Fin n) (Fin m) ℝ) (Matrix (Fin n) (Fin m) ℝ) instHSMul k := (fun (M : Matrix (Fin n) (Fin m) Float) i j => (OfNat.ofNat k) * (M i j)) -addRealToFloat (n k i) : @HSMul.hSMul ℕ ((Fin n) → ℝ) ((Fin n) → ℝ) i k := +add_real_to_float (n k i) : @HSMul.hSMul ℕ ((Fin n) → ℝ) ((Fin n) → ℝ) i k := (fun (x : (Fin n) → Float) i => (OfNat.ofNat k) * (x i)) -addRealToFloat (n m k : Nat) : @SMul.smul ℕ (Matrix (Fin n) (Fin m) ℝ) AddMonoid.toNatSMul k := +add_real_to_float (n m k : Nat) : @SMul.smul ℕ (Matrix (Fin n) (Fin m) ℝ) AddMonoid.toNatSMul k := (fun (M : Matrix (Fin n) (Fin m) Float) i j => (OfNat.ofNat k) * (M i j)) -addRealToFloat (n k i) : @SMul.smul ℕ ((Fin n) → ℝ) i k := +add_real_to_float (n k i) : @SMul.smul ℕ ((Fin n) → ℝ) i k := (fun (x : (Fin n) → Float) i => (OfNat.ofNat k) * (x i)) -addRealToFloat (i1 i2 i3) : @Algebra.toSMul ℝ ℝ i1 i2 i3 := +add_real_to_float (i1 i2 i3) : @Algebra.toSMul ℝ ℝ i1 i2 i3 := SMul.mk Float.mul -addRealToFloat : @Matrix.vecEmpty Real := +add_real_to_float : @Matrix.vecEmpty Real := fun (x : Fin 0) => ((False.elim (Nat.not_lt_zero x.1 x.2)) : Float) -addRealToFloat (n) : @Matrix.vecEmpty (Fin n → Real) := +add_real_to_float (n) : @Matrix.vecEmpty (Fin n → Real) := fun (x : Fin 0) => ((False.elim (Nat.not_lt_zero x.1 x.2)) : Fin n → Float) -addRealToFloat : @Matrix.transpose.{0,0,0} := +add_real_to_float : @Matrix.transpose.{0,0,0} := @Matrix.Computable.transpose.{0,0,0} -addRealToFloat (n) (α) (i1) (i2) : @Matrix.diagonal.{0,0} (Fin n) α i1 i2 := +add_real_to_float (n) (α) (i1) (i2) : @Matrix.diagonal.{0,0} (Fin n) α i1 i2 := @Matrix.Computable.diagonal n -addRealToFloat : @Matrix.fromBlocks.{0,0,0,0,0} := +add_real_to_float : @Matrix.fromBlocks.{0,0,0,0,0} := @Matrix.Computable.fromBlocks -addRealToFloat : @Matrix.diag.{0,0} := +add_real_to_float : @Matrix.diag.{0,0} := @Matrix.Computable.diag.{0,0} -addRealToFloat (n) (i) (hn) : @Vec.sum.{0} ℝ i (Fin n) hn := +add_real_to_float (n) (i) (hn) : @Vec.sum.{0} ℝ i (Fin n) hn := @Vec.Computable.sum n -addRealToFloat (n) (i) (hn) : @Matrix.sum (Fin n) Real hn i := +add_real_to_float (n) (i) (hn) : @Matrix.sum (Fin n) Real hn i := @Matrix.Computable.sum n -addRealToFloat (n) (i) : @Vec.cumsum.{0} ℝ i n := +add_real_to_float (n) (i) : @Vec.cumsum.{0} ℝ i n := @Vec.Computable.cumsum n -addRealToFloat : @Vec.norm := +add_real_to_float : @Vec.norm := @Vec.Computable.norm -addRealToFloat (n) (i1) (i2) (i3) : @Matrix.dotProduct (Fin n) ℝ i1 i2 i3 := +add_real_to_float (n) (i1) (i2) (i3) : @Matrix.dotProduct (Fin n) ℝ i1 i2 i3 := @Matrix.Computable.dotProduct n -addRealToFloat (n m) (i1) (i2) : @Matrix.mulVec (Fin n) (Fin m) ℝ i1 i2 := +add_real_to_float (n m) (i1) (i2) : @Matrix.mulVec (Fin n) (Fin m) ℝ i1 i2 := @Matrix.Computable.mulVec n m -addRealToFloat (n m) (i1) (i2) : @Matrix.vecMul (Fin n) (Fin m) ℝ i1 i2 := +add_real_to_float (n m) (i1) (i2) : @Matrix.vecMul (Fin n) (Fin m) ℝ i1 i2 := @Matrix.Computable.vecMul n m -addRealToFloat (n : Nat) (i1) (i2) : @Matrix.trace (Fin n) ℝ i1 i2 := +add_real_to_float (n : Nat) (i1) (i2) : @Matrix.trace (Fin n) ℝ i1 i2 := @Matrix.Computable.trace n -addRealToFloat (l m n) (i) : +add_real_to_float (l m n) (i) : @HMul.hMul (Matrix (Fin l) (Fin m) ℝ) (Matrix (Fin m) (Fin n) ℝ) (Matrix (Fin m) (Fin n) ℝ) i := @Matrix.Computable.mul l m n -addRealToFloat (n : Nat) (i1) (i2) : @Matrix.toUpperTri.{0,0} (Fin n) ℝ i1 i2 := +add_real_to_float (n : Nat) (i1) (i2) : @Matrix.toUpperTri.{0,0} (Fin n) ℝ i1 i2 := @Matrix.Computable.toUpperTri n -addRealToFloat (n) (i) : @Inv.inv (Matrix (Fin n) (Fin n) ℝ) i := +add_real_to_float (n) (i) : @Inv.inv (Matrix (Fin n) (Fin n) ℝ) i := @Matrix.Computable.inv n end Matrix section CovarianceEstimation -addRealToFloat (N n : ℕ) : @covarianceMatrix N n := +add_real_to_float (N n : ℕ) : @covarianceMatrix N n := @Matrix.Computable.covarianceMatrix N n end CovarianceEstimation diff --git a/CvxLean/Command/Solve/Mosek/CBF.lean b/CvxLean/Command/Solve/Mosek/CBF.lean index d9eaea96..27b9fe84 100644 --- a/CvxLean/Command/Solve/Mosek/CBF.lean +++ b/CvxLean/Command/Solve/Mosek/CBF.lean @@ -25,7 +25,7 @@ end ObjSense /-- Cone type: free, positive orthant, negative orthant, fixpoint zero (for equality), quadratic, quadratic rotated or exponential. The quadratic cone is called second-order in other parts of the project. Note that this definition is independent from the cone types in -`Command/Solve/Float/ProblemData.lean`. -/ +`CvxLean/Command/Solve/Float/ProblemData.lean`. -/ inductive ConeType | F | LPos | LNeg | LEq | Q | QR | EXP deriving DecidableEq diff --git a/CvxLean/Command/Solve/Mosek/CBFOfProblemData.lean b/CvxLean/Command/Solve/Mosek/CBFOfProblemData.lean index d6d83f8f..fc3dfbaa 100644 --- a/CvxLean/Command/Solve/Mosek/CBFOfProblemData.lean +++ b/CvxLean/Command/Solve/Mosek/CBFOfProblemData.lean @@ -57,7 +57,8 @@ unsafe def getTotalDim (minExpr : MinimizationExpr) : MetaM Nat := do namespace CBF -/-- Translate generic cone types from `Solve/Float/ProblemData.lean` to MOSEK cones. -/ +/-- Translate generic cone types from `CvxLean/Command/Solve/Float/ProblemData.lean` to MOSEK +cones. -/ def translateCone : ScalarConeType → CBF.ConeType | ScalarConeType.Zero => CBF.ConeType.LEq | ScalarConeType.PosOrth => CBF.ConeType.LPos @@ -67,8 +68,8 @@ def translateCone : ScalarConeType → CBF.ConeType /-- Constraints in non-matrix cones with intrinsic dimension (`𝒦ₑ`, `𝒬ⁿ⁺¹`, and `𝒬ᵣⁿ⁺²`) are defined by a series of constraints. These need to be grouped appropriately. In -`Command/Solve/Float/Coeffs.lean` we keep track of these groups as "sections". These are used here -to correctly write the problem in CBF. -/ +`CvxLean/Command/Solve/Float/Coeffs.lean` we keep track of these groups as "sections". These are +used here to correctly write the problem in CBF. -/ def groupCones (sections : ScalarAffineSections) (l : List CBF.Cone) : MetaM (List CBF.Cone) := do let l := l.toArray let mut res := [] diff --git a/CvxLean/Demos/WIP.lean b/CvxLean/Demos/WIP.lean deleted file mode 100644 index 28855549..00000000 --- a/CvxLean/Demos/WIP.lean +++ /dev/null @@ -1,194 +0,0 @@ -import CvxLean - -open CvxLean Minimization Real - -noncomputable section - -namespace Chapter3 - -namespace CvxLean_overview - -def p := - optimization (x y : ℝ) - minimize 40 * x + 30 * y - subject to - c₁ : 12 ≤ x + y - c₂ : 16 ≤ 2 * x + y - -example : p = - Minimization.mk - (fun p => 40 * p.1 + 30 * p.2) - (fun p => 12 ≤ p.1 + p.2 ∧ 16 ≤ 2 * p.1 + p.2) := - rfl - -def pSolution : Solution p := - { point := ⟨4, 8⟩, - isOptimal := by - split_ands <;> try norm_num - intros x' y' h_feas; simp [feasible, p] at h_feas ⊢; linarith } - -end CvxLean_overview - -namespace Equivalences - -def pex₁ := Minimization.mk (fun (x : ℝ) => x) (fun x => 1 ≤ x) -def pex₂ := Minimization.mk (fun (x : ℝ) => log x) (fun x => 1 ≤ x) -def pex₃ := Minimization.mk (fun (x : ℝ) => (log x) ^ (2 : ℝ)) (fun x => 1 ≤ x) - -def E₁₂ : pex₁ ≡ pex₂ := Equivalence.map_objFun_log (fun x h => by positivity!) - -def E₂₃ : pex₂ ≡ pex₃ := Equivalence.map_objFun_sq (fun x h => by positivity!) - -def nS₁₂ : (pex₁ ≡' pex₂) → False := - fun ⟨_, psi, _, psi_feasibility, _, psi_optimality⟩ => by - simp [pex₁, pex₂, feasible, constraints, objFun] at * - let y : ℝ := 1 - have h_feas_y : 1 ≤ y := by norm_num - have h_one_le_phiy := psi_feasibility y h_feas_y - have h_psiy_le_logy := psi_optimality y h_feas_y - have h_one_le_logy := le_trans h_one_le_phiy h_psiy_le_logy - simp at h_one_le_logy - have hc := lt_of_le_of_lt h_one_le_logy (zero_lt_one (α:= ℝ)) - exact lt_irrefl _ hc - -def S₂₃ : pex₂ ≡' pex₃ := - { phi := fun x => (exp (sqrt (log x))) - psi := fun x => (exp ((log x) ^ 2)) - phi_feasibility := fun x _ => by - simp [pex₂, pex₃, feasible, constraints, objFun] at * - exact sqrt_nonneg _, - psi_feasibility := fun x h_feas_x => by - simp [pex₂, pex₃, feasible, constraints, objFun] at * - exact sq_nonneg _, - phi_optimality := fun x h_feas_x => by - simp [pex₂, pex₃, feasible, constraints, objFun] at * - have h_logx_nonneg : 0 ≤ log x := log_nonneg h_feas_x - rw [sq_sqrt h_logx_nonneg], - psi_optimality := fun x _ => by - simp [pex₂, pex₃, feasible, constraints, objFun] at * } - -def p₄ := Minimization.mk (fun (x : ℝ) => x) (fun x => 0 ≤ x) - -def S₄₂ : p₄ ≡' pex₂ := - { phi := fun x => exp x - psi := fun x => log x - phi_feasibility := fun x _ => by - simp [p₄, pex₂, feasible, constraints, objFun] at * - positivity!, - psi_feasibility := fun x h_feas_x => by - simp [p₄, pex₂, feasible, constraints, objFun] at * - positivity!, - phi_optimality := fun x _ => by - simp [p₄, pex₂, feasible, constraints, objFun] at *, - psi_optimality := fun x _ => by - simp [p₄, pex₂, feasible, constraints, objFun] at *, } - -def nS₄₁ : (p₄ ≡' pex₁) → False := fun h => nS₁₂ <| h.symm.trans S₄₂ - -equivalence eqv₁/p₂ : - optimization (x y : ℝ) - minimize (x * y) - subject to - c₁ : 0 < x - c₂ : 1 ≤ y := by - change_of_variables! (u) (x ↦ exp u) - change_of_variables! (v) (y ↦ exp v) - -equivalence eqv₂/p₃ : p₂ := by - conv_obj => - rw [← exp_add] - conv_constr c₂ => - rw [← log_le_log (by norm_num) (exp_pos _), log_exp] - norm_num - -open Equivalence - -set_option trace.Meta.debug true - -equivalence eqv/q : - optimization (x : ℝ) - minimize sqrt (x ^ 2) - subject to - c₁ : 0 ≤ x := by - rw_obj => - rw [sqrt_sq c₁] - -end Equivalences - -namespace Relaxations - -def p₁ := - optimization (x y : ℝ) - minimize x + y - subject to - c₁ : 0 ≤ x - c₂ : 0 ≤ y - -def p₂ := - optimization (x : ℝ) - minimize x - subject to - c₁ : 0 ≤ x - -def Rx₁₂ : p₁ ≽' p₂ := - { phi := fun (x, _) => x, - phi_feasibility := fun _ ⟨hx, _⟩ => hx, - phi_optimality := fun _ ⟨_, hy⟩ => by simpa [objFun, p₁, p₂] } - -open Real Matrix - -lemma PosSemiDef_is_symmetric {n} {A : Matrix (Fin n) (Fin n) ℝ} (hA : PosSemidef A) : - ∀ i j, A i j = A j i := by - let h_A_IsHermitian := hA.1 - rw [Matrix.isHermitian_iff_isSymmetric] at h_A_IsHermitian - simp [LinearMap.IsSymmetric, toEuclideanLin] at h_A_IsHermitian - intros i j - have hAij := h_A_IsHermitian (fun k => if k = i then 1 else 0) (fun k => if k = j then 1 else 0) - have hi : (Finset.sum Finset.univ fun x => A j x * (Equiv.refl (WithLp 2 (Fin n → ℝ))) - (fun k => if k = i then 1 else 0) x) = A j i := by simp - have hj : (Finset.sum Finset.univ fun x => A i x * (Equiv.refl (WithLp 2 (Fin n → ℝ))) - (fun k => if k = j then 1 else 0) x) = A i j := by simp - simp [WithLp.equiv, mulVec, dotProduct] at hAij - erw [hi, hj] at hAij - rw [hAij] - -lemma trace_mul_transpose_self_eq_quad_of_symm {n} (A : Matrix (Fin n) (Fin n) ℝ) (x : Fin n → ℝ) - (h_A_symm : ∀ i j, A i j = A j i) : - trace (A * (Vec.toMatrix x * (Vec.toMatrix x)ᵀ)) = vecMul x A ⬝ᵥ x := by - simp [trace, dotProduct] - congr; funext i; - simp [Matrix.mul_apply, vecMul, dotProduct, Finset.sum_mul] - congr; funext j; - simp [Vec.toMatrix] - rw [h_A_symm i j] - ring - -def sdr {n} {A C : Matrix (Fin n) (Fin n) ℝ} (hA : PosSemidef A) (hC : PosSemidef C) (b : ℝ) : - (optimization (x : Fin n → ℝ) - minimize (vecMul x C) ⬝ᵥ x - subject to - c₁ : (vecMul x A) ⬝ᵥ x ≥ b) ≽' - (optimization (X : Matrix (Fin n) (Fin n) ℝ) - minimize trace (C * X) - subject to - c₁ : trace (A * X) ≥ b - c₂ : PosSemidef X) := - { phi := fun x => (Vec.toMatrix x) * (Vec.toMatrix x)ᵀ, - phi_feasibility := fun x h_feas_x => by - simp [objFun, feasible, constraints, p₁, p₂, Vec.toMatrix] at * - have h_A_symm := PosSemiDef_is_symmetric hA - rw [trace_mul_transpose_self_eq_quad_of_symm A x h_A_symm] - refine ⟨h_feas_x, ?_⟩ - rw [← conjTranspose_eq_transpose] - exact posSemidef_conjTranspose_mul_self _, - phi_optimality := fun x h_feas_x => by - simp [objFun, feasible, constraints, p₁, p₂] at * - have h_C_symm := PosSemiDef_is_symmetric hC - rw [trace_mul_transpose_self_eq_quad_of_symm C x h_C_symm] } - - -end Relaxations - -end Chapter3 - -end diff --git a/CvxLean/Examples/FittingSphere.lean b/CvxLean/Examples/FittingSphere.lean index a1aae3c3..7d614b6e 100644 --- a/CvxLean/Examples/FittingSphere.lean +++ b/CvxLean/Examples/FittingSphere.lean @@ -102,16 +102,11 @@ equivalence' eqv/fittingSphereT (n m : ℕ) (x : Fin m → Fin n → ℝ) : fitt conv_constr h₁ => dsimp conv_obj => dsimp -- Rewrite objective. - equivalence_step => - apply Equivalence.rewrite_objFun - (g := fun (ct : (Fin n → ℝ) × ℝ) => - Vec.sum (((Vec.norm x) ^ 2 - 2 * (Matrix.mulVec x ct.1) - Vec.const m ct.2) ^ 2)) - . rintro ⟨c, t⟩ h - dsimp at h ⊢; simp [Vec.sum, Vec.norm, Vec.const]; congr; funext i; congr 1; - rw [norm_sub_sq (𝕜 := ℝ) (E := Fin n → ℝ), sq_sqrt (rpow_two _ ▸ le_of_lt (sqrt_pos.mp h))] - simp [mulVec, inner, dotProduct] - rename_vars [c, t] - + rw_obj into (Vec.sum (((Vec.norm x) ^ 2 - 2 * (Matrix.mulVec x c) - Vec.const m t) ^ 2)) => + dsimp at h₁ ⊢; simp [Vec.sum, Vec.norm, Vec.const]; congr; funext i; congr 1; + rw [norm_sub_sq (𝕜 := ℝ) (E := Fin n → ℝ), sq_sqrt (rpow_two _ ▸ le_of_lt (sqrt_pos.mp h₁))] + simp [mulVec, inner, dotProduct] + #print fittingSphereT -- optimization (c : Fin n → ℝ) (t : ℝ) -- minimize Vec.sum ((Vec.norm x ^ 2 - 2 * mulVec x c - Vec.const m t) ^ 2) diff --git a/CvxLean/Examples/HypersonicShapeDesign.lean b/CvxLean/Examples/HypersonicShapeDesign.lean index 5b11bd3c..e7580962 100644 --- a/CvxLean/Examples/HypersonicShapeDesign.lean +++ b/CvxLean/Examples/HypersonicShapeDesign.lean @@ -91,22 +91,20 @@ equivalence' eqv₂/hypersonicShapeDesignSimpler (a b : ℝ) (ha : 0 ≤ a) (hb rewrite [le_sqrt' (by norm_num)]; norm_num conv_constr h₂ => rewrite [sqrt_le_iff]; norm_num - rw_constr h₃ => + rw_constr h₃ into (a ^ 2 * z⁻¹ ≤ (1 - b) ^ 2 * (1 - z)) => have hz : 0 ≤ z := by arith have h_one_sub_z : 0 ≤ 1 - z := by arith - rewrite [rpow_two (sqrt a), sq_sqrt ha, rpow_two (sqrt z), sq_sqrt hz] - rewrite [div_le_iff (by arith)] + rw [rpow_two (sqrt a), sq_sqrt ha, rpow_two (sqrt z), sq_sqrt hz] + rw [div_le_iff (by arith)] have hlhs : 0 ≤ a / sqrt z := div_nonneg ha (sqrt_nonneg _) have hrhs : 0 ≤ sqrt (1 - z) * (1 - b) := mul_nonneg (sqrt_nonneg _) (by arith) - rewrite [← pow_two_le_pow_two hlhs hrhs] - rewrite [div_rpow ha (sqrt_nonneg _), rpow_two (sqrt z), sq_sqrt hz] - rewrite [mul_rpow (sqrt_nonneg _) (by arith), rpow_two (sqrt (1 - z)), sq_sqrt h_one_sub_z] - rewrite [← mul_one_div, ← inv_eq_one_div, mul_comm (1 - z) _] - rfl + rw [← pow_two_le_pow_two hlhs hrhs] + rw [div_rpow ha (sqrt_nonneg _), rpow_two (sqrt z), sq_sqrt hz] + rw [mul_rpow (sqrt_nonneg _) (by arith), rpow_two (sqrt (1 - z)), sq_sqrt h_one_sub_z] + rw [← mul_one_div, ← inv_eq_one_div, mul_comm (1 - z) _] rename_constrs [h₁, h₂, h₃] - rw_obj => - rewrite [rpow_neg (sqrt_nonneg _), rpow_two (sqrt z), sq_sqrt (by arith)] - rfl + rw_obj into (z⁻¹ - 1) => + rw [rpow_neg (sqrt_nonneg _), rpow_two (sqrt z), sq_sqrt (by arith)] #print hypersonicShapeDesignSimpler -- optimization (z : ℝ) diff --git a/CvxLean/Examples/TrussDesign.lean b/CvxLean/Examples/TrussDesign.lean index bcfe5134..3fdd3ce9 100644 --- a/CvxLean/Examples/TrussDesign.lean +++ b/CvxLean/Examples/TrussDesign.lean @@ -56,7 +56,6 @@ instance : ChangeOfVariables simp; rw [mul_comm _ (R ^ 2 - r ^ 2), ← mul_div, div_self (by positivity), mul_one] ring_nf; exact sqrt_sq hR } -set_option maxHeartbeats 1000000 in equivalence' eqv₁/trussDesignGP (hmin hmax wmin wmax Rmax σ F₁ F₂ : ℝ) : trussDesign hmin hmax wmin wmax Rmax σ F₁ F₂ := by -- Apply key change of variables. @@ -72,44 +71,28 @@ equivalence' eqv₁/trussDesignGP (hmin hmax wmin wmax Rmax σ F₁ F₂ : ℝ) -- Some cleanup. conv_opt => dsimp -- Rewrite contraint `c_R_lb`. - equivalence_step => - apply Equivalence.rewrite_constraint_8 - (c8' := fun (hwrA : ℝ × ℝ × ℝ × ℝ) => 0.21 * hwrA.2.2.1 ^ 2 ≤ hwrA.2.2.2 / (2 * π)) - . rintro ⟨h, w, r, A⟩ c_r _ _ _ _ _ _ _; dsimp at * - rw [le_sqrt' (by positivity), rpow_two, mul_pow, ← sub_le_iff_le_add] - rw [iff_eq_eq]; congr; ring_nf + rw_constr c_R_lb into (0.21 * r ^ 2 ≤ A / (2 * π)) => + rw [le_sqrt' (by positivity), rpow_two, mul_pow, ← sub_le_iff_le_add] + rw [iff_eq_eq]; congr; ring_nf + -- Useful fact about constraints used for simplification. This is a good example of introducing + -- local lemmas in an `equivalence` environment. have h_A_div_2π_add_r2_nonneg : ∀ (r A : ℝ) (c_r : 0 < r) (c_R_lb : 0.21 * r ^ 2 ≤ A / (2 * π)), 0 ≤ A / (2 * π) + r ^ 2 := fun r A c_r c_R_lb => by have h_A_div_2π_nonneg : 0 ≤ A / (2 * π) := le_trans (by positivity) c_R_lb exact add_nonneg (h_A_div_2π_nonneg) (by positivity) -- Simplify objective function. - equivalence_step => - apply Equivalence.rewrite_objFun - (g := fun (hwrA : ℝ × ℝ × ℝ × ℝ) => 2 * hwrA.2.2.2 * sqrt (hwrA.2.1 ^ 2 + hwrA.1 ^ 2)) - . rintro ⟨h, w, r, A⟩ ⟨c_r, _, _, _, _, _, _, c_R_lb, _⟩; dsimp at * - rw [rpow_two, sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] - ring_nf; field_simp; ring - rename_vars [h, w, r, A] + rw_obj into (2 * A * sqrt (w ^ 2 + h ^ 2)) => + rw [rpow_two, sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] + ring_nf; field_simp; ring -- Simplify constraint `c_F₁`. - equivalence_step => - apply Equivalence.rewrite_constraint_2 - (c2' := fun (hwrA : ℝ × ℝ × ℝ × ℝ) => - F₁ * sqrt (hwrA.2.1 ^ 2 + hwrA.1 ^ 2) / (2 * hwrA.1) ≤ σ * hwrA.2.2.2) - . rintro ⟨h, w, r, A⟩ c_r ⟨_, _, _, _, _, c_R_lb, _⟩; dsimp at * - rw [rpow_two (sqrt (_ + r ^ 2)), sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] - rw [iff_eq_eq]; congr; ring_nf; field_simp; ring - rename_vars [h, w, r, A] + rw_constr c_F₁ into (F₁ * sqrt (w ^ 2 + h ^ 2) / (2 * h) ≤ σ * A) => + rw [rpow_two (sqrt (_ + r ^ 2)), sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] + rw [iff_eq_eq]; congr; ring_nf; field_simp; ring -- Simplify constraint `c_F₂`. - equivalence_step => - apply Equivalence.rewrite_constraint_3 - (c3' := fun (hwrA : ℝ × ℝ × ℝ × ℝ) => - F₂ * sqrt (hwrA.2.1 ^ 2 + hwrA.1 ^ 2) / (2 * hwrA.2.1) ≤ σ * hwrA.2.2.2) - . rintro ⟨h, w, r, A⟩ c_r _ ⟨_, _, _, _, c_R_lb, _⟩; dsimp at * - rw [rpow_two (sqrt (_ + r ^ 2)), sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] - rw [iff_eq_eq]; congr; ring_nf; field_simp; ring - rename_vars [h, w, r, A] - rename_constrs [c_r, c_F₁, c_F₂, c_hmin, c_hmax, c_wmin, c_wmax, c_A_lb, c_A_ub] + rw_constr c_F₂ into (F₂ * sqrt (w ^ 2 + h ^ 2) / (2 * w) ≤ σ * A) => + rw [rpow_two (sqrt (_ + r ^ 2)), sq_sqrt (h_A_div_2π_add_r2_nonneg r A c_r c_R_lb)] + rw [iff_eq_eq]; congr; ring_nf; field_simp; ring #print trussDesignGP -- minimize 2 * A * sqrt (w ^ 2 + h ^ 2) diff --git a/CvxLean/Examples/VehicleSpeedScheduling.lean b/CvxLean/Examples/VehicleSpeedScheduling.lean index 836e0d01..1720ab72 100644 --- a/CvxLean/Examples/VehicleSpeedScheduling.lean +++ b/CvxLean/Examples/VehicleSpeedScheduling.lean @@ -69,23 +69,16 @@ equivalence' eqv₁/vehSpeedSchedConvex (n : ℕ) (d : Fin n → ℝ) conv_constr c_τmin => simp only [Pi.div_apply, simp_vec_fraction d h_d_pos] conv_constr c_τmax => simp only [Pi.div_apply, simp_vec_fraction d h_d_pos] -- Put in matrix form. - equivalence_step => - apply Equivalence.rewrite_objFun (g := fun t => Vec.sum (t * (Vec.map F (d / t)))) - . intro t _; simp [Vec.sum]; rfl - equivalence_step => - apply Equivalence.rewrite_constraint_1 (c1' := fun t => smin ≤ d / t) - . intro t _; rfl - equivalence_step => - apply Equivalence.rewrite_constraint_2 (c2' := fun t => d / t ≤ smax) - . intro t _ _; rfl - equivalence_step => - apply Equivalence.rewrite_constraint_3 (c3' := fun t => τmin ≤ Vec.cumsum t) - . intro t _ _ _; simp [fold_partial_sum t]; rfl - equivalence_step => - apply Equivalence.rewrite_constraint_4_last (c4' := fun t => Vec.cumsum t ≤ τmax) - . intro t _ _ _; simp [fold_partial_sum t]; rfl - rename_vars [t] - rename_constrs [c_smin, c_smax, c_τmin, c_τmax] + rw_obj into (Vec.sum (t * (Vec.map F (d / t)))) => + simp [Vec.sum]; rfl + rw_constr c_smin into (smin ≤ d / t) => + rfl + rw_constr c_smax into (d / t ≤ smax) => + rfl + rw_constr c_τmin into (τmin ≤ Vec.cumsum t) => + simp [fold_partial_sum t]; rfl + rw_constr c_τmax into (Vec.cumsum t ≤ τmax) => + simp [fold_partial_sum t]; rfl #print vehSpeedSchedConvex -- optimization (t : Fin n → ℝ) @@ -120,26 +113,18 @@ equivalence' eqv₂/vehSpeedSchedQuadratic (n : ℕ) (d : Fin n → ℝ) apply Equivalence.add_constraint (cs' := fun t => StrongLT 0 t) . rintro t ⟨c_smin, _⟩ i exact t_pos_of_c_smin t c_smin i + rename_vars [t] rename_constrs [c_t, c_smin, c_smax, c_τmin, c_τmax] -- Arithmetic simplification in the objective function. - equivalence_step => - apply Equivalence.rewrite_objFun - (g := fun t => Vec.sum (a • (d ^ (2 : ℝ)) * (1 / t) + b • d + c • t)) - . rintro t ⟨c_t, _⟩ - congr; funext i; unfold Vec.map; dsimp - have h_ti_pos : 0 < t i := c_t i - field_simp; ring + rw_obj into (Vec.sum (a • (d ^ (2 : ℝ)) * (1 / t) + b • d + c • t)) => + congr; funext i; unfold Vec.map; dsimp + have h_ti_pos : 0 < t i := c_t i + field_simp; ring -- Rewrite linear constraints. - equivalence_step => - apply Equivalence.rewrite_constraint_2 (c2' := fun t => smin * t ≤ d) - . intro t c_t _; rw [Vec.le_div_iff c_t] - rename_vars [t] - rename_constrs [c_t, c_smin, c_smax, c_τmin, c_τmax] - equivalence_step => - apply Equivalence.rewrite_constraint_3 (c3' := fun t => d ≤ smax * t) - . intro t c_t _ _; rw [Vec.div_le_iff c_t] - rename_vars [t] - rename_constrs [c_t, c_smin, c_smax, c_τmin, c_τmax] + rw_constr c_smin into (smin * t ≤ d) => + rw [Vec.le_div_iff c_t] + rw_constr c_smax into (d ≤ smax * t) => + rw [Vec.div_le_iff c_t] -- Finally, we can apply `dcp`! (or we can call `solve`, as we do below). #print vehSpeedSchedQuadratic diff --git a/CvxLean/Lib/Equivalence.lean b/CvxLean/Lib/Equivalence.lean index 5bc28ce2..2f134594 100644 --- a/CvxLean/Lib/Equivalence.lean +++ b/CvxLean/Lib/Equivalence.lean @@ -7,10 +7,10 @@ import CvxLean.Meta.Attributes We define equivalence and strong equivalence, and several equivalence-preserving transformations. -## References +### References -* [S. Boyd and L. Vandenberghe, *Convex Optimization*][BV04] -* [M. C. Grant, *Discipliend Convex Programming*][Gra05] +* [BV04] S. Boyd and L. Vandenberghe, *Convex Optimization* +* [Gra05] M. C. Grant, *Discipliend Convex Programming* -/ namespace Minimization @@ -204,10 +204,11 @@ namespace Equivalence variable {f : D → R} {cs : D → Prop} -/- Mapping the objective function by monotonic functions yields an equivalence. Also, mapping the -whole domain by a function with a right inverse. -/ section Maps +/-! Mapping the objective function by monotonic functions yields an equivalence. Also, mapping the +whole domain by a function with a right inverse. -/ + /-- See [BV04,p.131] where `g` is `ψ₀`. -/ @[equiv] def map_objFun {g : R → R} (h : ∀ {r s}, cs r → cs s → (g (f r) ≤ g (f s) ↔ f r ≤ f s)) : @@ -248,9 +249,10 @@ def map_domain {f : D → R} {cs : D → Prop} {fwd : D → E} {bwd : E → D} end Maps -/- Rewriting the objective function or the constraints leads to equivalent problems. -/ section Rewrites +/-! Rewriting the objective function or the constraints leads to equivalent problems. -/ + variable {f : D → R} variable {c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 : D → Prop} variable {c1' c2' c3' c4' c5' c6' c7' c8' c9' c10' : D → Prop} @@ -481,10 +483,11 @@ def rewrite_constraint_10_last end Rewrites -/- Other equivalence-preserving transformations. These are not used by any tactic but can be used +section Other + +/-! Other equivalence-preserving transformations. These are not used by any tactic but can be used directly. They provide evidence that our notion of equivalence captures the expected transformations. -/ -section Other variable {f : D → R} {cs : D → Prop} diff --git a/CvxLean/Lib/Math/CovarianceEstimation.lean b/CvxLean/Lib/Math/CovarianceEstimation.lean index ff15babe..ef9d623d 100644 --- a/CvxLean/Lib/Math/CovarianceEstimation.lean +++ b/CvxLean/Lib/Math/CovarianceEstimation.lean @@ -1,7 +1,7 @@ import CvxLean.Lib.Math.LinearAlgebra.Matrix.PosDef /-! -Definitions needed for `Examples/CovarianceEstimation.lean`. +Definitions needed for `CvxLean/Examples/CovarianceEstimation.lean`. -/ noncomputable section CovarianceEstimation diff --git a/CvxLean/Lib/Math/Data/Real.lean b/CvxLean/Lib/Math/Data/Real.lean index 127f4ab5..25963dc1 100644 --- a/CvxLean/Lib/Math/Data/Real.lean +++ b/CvxLean/Lib/Math/Data/Real.lean @@ -68,10 +68,11 @@ lemma abs_le_of_sqrt_sq_add_nonneg_le {a b c : ℝ} (hb : 0 ≤ b) replace h := le_trans (le_add_of_nonneg_right hb) h rwa [rpow_two, sq_le_sq, abs_of_nonneg hc] at h -/- Lemmas used in `CvxLean.Tactic.PreDCP.RewriteMapLibrary`. -/ +/- Lemmas used in `CvxLean.Tactic.PreDCP.RuleToTacticLibrary`. -/ -lemma log_eq_log {x y : ℝ} (hx : 0 < x) (hy : 0 < y) : log x = log y ↔ x = y := - ⟨fun h => by +lemma log_eq_log {x y : ℝ} (hx : 0 < x) (hy : 0 < y) : x = y ↔ log x = log y := + ⟨fun h => by rw [h], + fun h => by have hxmem := Set.mem_Ioi.2 hx have hymem := Set.mem_Ioi.2 hy have heq : Set.restrict (Set.Ioi 0) log ⟨x, hxmem⟩ = @@ -79,8 +80,7 @@ lemma log_eq_log {x y : ℝ} (hx : 0 < x) (hy : 0 < y) : log x = log y ↔ x = y simp [h] have h := log_injOn_pos.injective heq simp [Subtype.eq] at h - exact h, - fun h => by rw [h]⟩ + exact h⟩ lemma div_pow_eq_mul_pow_neg {a b c : ℝ} (hb : 0 ≤ b) : a / (b ^ c) = a * b ^ (-c) := by diff --git a/CvxLean/Lib/Math/LinearAlgebra/Matrix/PosDef.lean b/CvxLean/Lib/Math/LinearAlgebra/Matrix/PosDef.lean index a1270e84..ca955387 100644 --- a/CvxLean/Lib/Math/LinearAlgebra/Matrix/PosDef.lean +++ b/CvxLean/Lib/Math/LinearAlgebra/Matrix/PosDef.lean @@ -124,4 +124,19 @@ lemma PosDef_inv_iff_PosDef [DecidableEq n] (M : Matrix n n ℝ) : M⁻¹.PosDef { intros hM exact hM.nonsingular_inv } +lemma PosSemiDef.IsSymm {n} {A : Matrix (Fin n) (Fin n) ℝ} (hA : PosSemidef A) : IsSymm A := by + let h_A_IsHermitian := hA.1 + rw [Matrix.isHermitian_iff_isSymmetric] at h_A_IsHermitian + simp [LinearMap.IsSymmetric, toEuclideanLin] at h_A_IsHermitian + apply IsSymm.ext + intros i j + have hAij := h_A_IsHermitian (fun k => if k = i then 1 else 0) (fun k => if k = j then 1 else 0) + have hi : (Finset.sum Finset.univ fun x => A j x * (Equiv.refl (WithLp 2 (Fin n → ℝ))) + (fun k => if k = i then 1 else 0) x) = A j i := by simp + have hj : (Finset.sum Finset.univ fun x => A i x * (Equiv.refl (WithLp 2 (Fin n → ℝ))) + (fun k => if k = j then 1 else 0) x) = A i j := by simp + simp [WithLp.equiv, mulVec, dotProduct] at hAij + erw [hi, hj] at hAij + rw [hAij] + end Matrix diff --git a/CvxLean/Lib/Reduction.lean b/CvxLean/Lib/Reduction.lean index cbf5778c..bae16a4d 100644 --- a/CvxLean/Lib/Reduction.lean +++ b/CvxLean/Lib/Reduction.lean @@ -7,11 +7,11 @@ We define the notion of reduction. It is a reflexive and transitive relation and backward map between solutions. The idea is that solving the reduced problem is "as hard" as solving the original problem. An equivalence gives a reduction. -## References +### References -* [C. Papadimitriou, *Computational Complexity*][Pap93] +* [Pap93] C. Papadimitriou, *Computational Complexity* -## TODO +### TODO * Build reductions from equivalences automatically in meta. -/ diff --git a/CvxLean/Lib/Relaxation.lean b/CvxLean/Lib/Relaxation.lean index ee17988f..1f330bb0 100644 --- a/CvxLean/Lib/Relaxation.lean +++ b/CvxLean/Lib/Relaxation.lean @@ -7,9 +7,9 @@ We define the notion of relaxation. It is a reflexive and transitive relation an forward map between solutions. The idea is that solving the original problem is "as hard" as solving the relaxed problem. A strong equivalence gives a relaxation. -## References +### References -* [RelaxationWiki] +* [RelaxationWiki] https://en.wikipedia.org/wiki/Relaxation_(approximation) -/ namespace Minimization diff --git a/CvxLean/Meta/TacticBuilder.lean b/CvxLean/Meta/TacticBuilder.lean index 5e123a67..06f4a91b 100644 --- a/CvxLean/Meta/TacticBuilder.lean +++ b/CvxLean/Meta/TacticBuilder.lean @@ -31,15 +31,17 @@ def isTransitive : TransformationGoal → Bool | _ => true def fromExpr (e : Expr) : MetaM TransformationGoal := do - if e.isAppOf `Minimization.Solution then + let e ← whnf e + if e.isAppOf ``Minimization.Solution then return TransformationGoal.Solution - else if e.isAppOf `Minimization.Equivalence then + else if e.isAppOf ``Minimization.Equivalence then return TransformationGoal.Equivalence - else if e.isAppOf `Minimization.Reduction then + else if e.isAppOf ``Minimization.Reduction then return TransformationGoal.Reduction - else if e.isAppOf `Minimization.Relaxation then + else if e.isAppOf ``Minimization.Relaxation then return TransformationGoal.Relaxation else + trace[Meta.TacticBuilder] "{e.ctorName}" throwTacticBuilderError "expected a `Solution`, `Equivalence`, `Reduction` or `Relaxation` goal, got {e}." @@ -67,11 +69,11 @@ end TransformationGoal /-- Given a relaxation goal in the form of a `RelaxationExpr` and the `MVarId` of the current goal, provide a tactic to close it. -/ -def RelaxationBuilder := RelaxationExpr → MVarId → TacticM Unit +def RelaxationBuilder (α) := RelaxationExpr → MVarId → TacticM α namespace RelaxationBuilder -def toTactic (builder : RelaxationBuilder) : TacticM Unit := withMainContext do +def toTactic {α} (builder : RelaxationBuilder α) : TacticM α := withMainContext do let transf ← TransformationGoal.fromExpr (← getMainTarget) -- Apply transitivity. @@ -91,21 +93,23 @@ def toTactic (builder : RelaxationBuilder) : TacticM Unit := withMainContext do -- Run builder. let relExpr ← RelaxationExpr.fromExpr (← gToChange.getType) - builder relExpr gToChange + let res ← builder relExpr gToChange -- Set next goal. gNext.setTag Name.anonymous setGoals [gNext] + return res + end RelaxationBuilder /-- Given a reduction goal in the form of a `ReductionExpr` and the `MVarId` of the current goal, provide a tactic to close it. -/ -def ReductionBuilder := ReductionExpr → MVarId → Tactic +def ReductionBuilder (α) := ReductionExpr → MVarId → TacticM α namespace ReductionBuilder -def toTactic (builder : ReductionBuilder) : Tactic := fun stx => do +def toTactic {α} (builder : ReductionBuilder α) : TacticM α := do let transf ← TransformationGoal.fromExpr (← getMainTarget) -- Apply transitivity. @@ -131,21 +135,23 @@ def toTactic (builder : ReductionBuilder) : Tactic := fun stx => do -- Run builder. let redExpr ← ReductionExpr.fromExpr (← gToChange.getType) - builder redExpr gToChange stx + let res ← builder redExpr gToChange -- Set next goal. gNext.instantiateMVars setGoals [gNext] + return res + end ReductionBuilder /-- Given an equivalence goal in the form of a `EquivalenceExpr` and the `MVarId` of the current goal, provide a tactic to close it. -/ -def EquivalenceBuilder := EquivalenceExpr → MVarId → TacticM Unit +def EquivalenceBuilder (α) := EquivalenceExpr → MVarId → TacticM α namespace EquivalenceBuilder -def toTactic (builder : EquivalenceBuilder) : TacticM Unit := withMainContext do +def toTactic {α} (builder : EquivalenceBuilder α) : TacticM α := withMainContext do let transf ← TransformationGoal.fromExpr (← getMainTarget) -- Apply transitivity. @@ -174,12 +180,14 @@ def toTactic (builder : EquivalenceBuilder) : TacticM Unit := withMainContext do -- Run builder. let eqvExpr ← EquivalenceExpr.fromExpr (← gToChange.getType) - builder eqvExpr gToChange + let res ← builder eqvExpr gToChange -- Set next goal. gNext.setTag Name.anonymous setGoals [gNext] + return res + end EquivalenceBuilder -- For reduction and equivalence commands. diff --git a/CvxLean/Meta/Util/Error.lean b/CvxLean/Meta/Util/Error.lean index 178c7cf7..c8a55a75 100644 --- a/CvxLean/Meta/Util/Error.lean +++ b/CvxLean/Meta/Util/Error.lean @@ -4,12 +4,21 @@ import Lean Custom error messages. -/ +section Syntax + +/-- Throws an error coming from parsing the optimization problem using the custom syntax defined in +`CvxLean/Syntax/Minimization.lean`. -/ syntax "throwParserError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwParserError $msg:interpolatedStr) => `(throwError ("Parser error: " ++ (m! $msg))) | `(throwParserError $msg:term) => `(throwError ("Parser error: " ++ $msg)) +end Syntax + +section Meta + +/-- Throws an error coming from the tactic builder machinery (`CvxLean/Meta/TacticBuilder.lean`). -/ syntax "throwTacticBuilderError " (interpolatedStr(term) <|> term) : term macro_rules @@ -17,50 +26,160 @@ macro_rules `(throwError ("Tactic builder error: " ++ (m! $msg))) | `(throwTacticBuilderError $msg:term) => `(throwError ("Tactic builder error: " ++ $msg)) +end Meta + +section BasicTactics + +/-- Throws an error coming from the `clean_up_comp` tactic. -/ +syntax "throwCleanUpCompError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwCleanUpCompError $msg:interpolatedStr) => + `(throwError ("`clean_up_comp` error: " ++ (m! $msg))) + | `(throwCleanUpCompError $msg:term) => `(throwError ("`clean_up_comp` error: " ++ $msg)) + +/-- Throws an error coming from the `change_of_variables` tactic. -/ +syntax "throwChangeOfVariablesError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwChangeOfVariablesError $msg:interpolatedStr) => + `(throwError ("`change_of_variables` error: " ++ (m! $msg))) + | `(throwChangeOfVariablesError $msg:term) => + `(throwError ("`change_of_variables` error: " ++ $msg)) + +/-- Throws an error coming from the `conv_opt` tactic. -/ +syntax "throwConvOptError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwConvOptError $msg:interpolatedStr) => `(throwError ("`conv_opt` error: " ++ (m! $msg))) + | `(throwConvOptError $msg:term) => `(throwError ("`conv_opt` error: " ++ $msg)) + +/-- Throws an error coming from the `conv_obj` tactic. -/ +syntax "throwConvObjError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwConvObjError $msg:interpolatedStr) => `(throwError ("`conv_obj` error: " ++ (m! $msg))) + | `(throwConvObjError $msg:term) => `(throwError ("`conv_obj` error: " ++ $msg)) + +/-- Throws an error coming from the `conv_constr` tactic. -/ +syntax "throwConvConstrError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwConvConstrError $msg:interpolatedStr) => + `(throwError ("`conv_constr` error: " ++ (m! $msg))) + | `(throwConvConstrError $msg:term) => `(throwError ("`conv_constr` error: " ++ $msg)) + +/-- Throws an error coming from the `remove_constrs` tactic. -/ +syntax "throwRemoveConstrError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwRemoveConstrError $msg:interpolatedStr) => + `(throwError ("`remove_constr` error: " ++ (m! $msg))) + | `(throwRemoveConstrError $msg:term) => `(throwError ("`remove_constr` error: " ++ $msg)) + +/-- Throws an error coming from the `remove_trivial_constrs` tactic. -/ +syntax "throwRemoveTrivialConstrsError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwRemoveTrivialConstrsError $msg:interpolatedStr) => + `(throwError ("`remove_trivial_constrs` error: " ++ (m! $msg))) + | `(throwRemoveTrivialConstrsError $msg:term) => + `(throwError ("`remove_trivial_constrs` error: " ++ $msg)) + +/-- Throws an error coming from the `rename_constrs` tactic. -/ +syntax "throwRenameConstrsError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwRenameConstrsError $msg:interpolatedStr) => + `(throwError ("`rename_constrs` error: " ++ (m! $msg))) + | `(throwRenameConstrsError $msg:term) => `(throwError ("`rename_constrs` error: " ++ $msg)) + +/-- Throws an error coming from the `rename_vars` tactic. -/ +syntax "throwRenameVarsError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwRenameVarsError $msg:interpolatedStr) => + `(throwError ("`rename_vars` error: " ++ (m! $msg))) + | `(throwRenameVarsError $msg:term) => `(throwError ("`rename_vars` error: " ++ $msg)) + +/-- Throws an error coming from the `reorder_constrs` tactic. -/ +syntax "throwReorderConstrsError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwReorderConstrsError $msg:interpolatedStr) => + `(throwError ("`reorder_constrs` error: " ++ (m! $msg))) + | `(throwReorderConstrsError $msg:term) => `(throwError ("`reorder_constrs` error: " ++ $msg)) + +/-- Throws an error coming from the `rw_constr` tactic. -/ syntax "throwRwConstrError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwRwConstrError $msg:interpolatedStr) => `(throwError ("`rw_constr` error: " ++ (m! $msg))) | `(throwRwConstrError $msg:term) => `(throwError ("`rw_constr` error: " ++ $msg)) +/-- Throws an error coming from the `rw_obj` tactic. -/ syntax "throwRwObjError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwRwObjError $msg:interpolatedStr) => `(throwError ("`rw_obj` error: " ++ (m! $msg))) | `(throwRwObjError $msg:term) => `(throwError ("`rw_obj` error: " ++ $msg)) +end BasicTactics + +section PreDCP + +/-- Throws an error coming from the `pre_dcp` tactic or any of the auxiliary functions. -/ syntax "throwPreDCPError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwPreDCPError $msg:interpolatedStr) => `(throwError ("`pre_dcp` error: " ++ (m! $msg))) | `(throwPreDCPError $msg:term) => `(throwError ("`pre_dcp` error: " ++ $msg)) +end PreDCP + +section DCP + +/-- Throws an error coming from the `dcp` tactic or any of the auxiliary functions. -/ syntax "throwDCPError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwDCPError $msg:interpolatedStr) => `(throwError ("`dcp` error: " ++ (m! $msg))) | `(throwDCPError $msg:term) => `(throwError ("`dcp` error: " ++ $msg)) +end DCP + +section Solve + +/-- Throws an error coming from the coefficient extraction procedure +(`CvxLean/Command/Solve/Float/Coeffs.lean`). -/ syntax "throwCoeffsError " (interpolatedStr(term) <|> term) : term macro_rules - | `(throwCoeffsError $msg:interpolatedStr) => `(throwError ("`coeffs` error: " ++ (m! $msg))) - | `(throwCoeffsError $msg:term) => `(throwError ("`coeffs` error: " ++ $msg)) + | `(throwCoeffsError $msg:interpolatedStr) => + `(throwError ("Coefficient extraction error: " ++ (m! $msg))) + | `(throwCoeffsError $msg:term) => `(throwError ("Coefficient extraction error: " ++ $msg)) +/-- Throws an error coming from the real-to-float procedure +(`CvxLean/Command/Solve/Float/RealToFloatCmd.lean`). -/ syntax "throwRealToFloatError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwRealToFloatError $msg:interpolatedStr) => - `(throwError ("`real-to-float` error: " ++ (m! $msg))) - | `(throwRealToFloatError $msg:term) => `(throwError ("`real-to-float` error: " ++ $msg)) + `(throwError ("Real-to-float error: " ++ (m! $msg))) + | `(throwRealToFloatError $msg:term) => `(throwError ("Real-to-float error: " ++ $msg)) +/-- Throws an error coming from the `solve` command. -/ syntax "throwSolveError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwSolveError $msg:interpolatedStr) => `(throwError ("`solve` error: " ++ (m! $msg))) | `(throwSolveError $msg:term) => `(throwError ("`solve` error: " ++ $msg)) -/-- Errors in the `equivalence` command. -/ +end Solve + +section Equivalence + +/-- Throws an error coming from the `equivalence` command. -/ syntax "throwEquivalenceError " (interpolatedStr(term) <|> term) : term macro_rules @@ -68,7 +187,11 @@ macro_rules `(throwError ("`equivalence` error: " ++ (m! $msg))) | `(throwEquivalenceError $msg:term) => `(throwError ("`equivalence` error: " ++ $msg)) -/-- Errors in the `reduction` command. -/ +end Equivalence + +section Reduction + +/-- Throws an error coming from the `reduction` command. -/ syntax "throwReductionError " (interpolatedStr(term) <|> term) : term macro_rules @@ -76,10 +199,14 @@ macro_rules `(throwError ("`reduction` error: " ++ (m! $msg))) | `(throwReductionError $msg:term) => `(throwError ("`reduction` error: " ++ $msg)) -/-- Errors in the `relaxation` command. -/ +section Relaxation + +/-- Throws an error coming from the `relaxation` command. -/ syntax "throwRelaxationError " (interpolatedStr(term) <|> term) : term macro_rules | `(throwRelaxationError $msg:interpolatedStr) => `(throwError ("`relaxation` error: " ++ (m! $msg))) | `(throwRelaxationError $msg:term) => `(throwError ("`relaxation` error: " ++ $msg)) + +end Relaxation diff --git a/CvxLean/Syntax/Minimization.lean b/CvxLean/Syntax/Minimization.lean index 1eb049ca..f25ba123 100644 --- a/CvxLean/Syntax/Minimization.lean +++ b/CvxLean/Syntax/Minimization.lean @@ -29,7 +29,7 @@ The full version with labels would be (fun p => 40 * p.1 + 30 * p.2) (fun p => {** 12 ≤ p.1 + p.2 ** `c₁ **} ∧ {** 16 ≤ 2 * p.1 + p.2 ** `c₂ **}) ``` -where the `{ ** e ** n }` notation is as defined in `Syntax/Label.lean`. +where the `{ ** e ** n }` notation is as defined in `CvxLean/Syntax/Label.lean`. We also define how to delaborate it so that it can be pretty printed back to the custom syntax. -/ diff --git a/CvxLean/Syntax/Parser.lean b/CvxLean/Syntax/Parser.lean index 36d0a9ce..552970cf 100644 --- a/CvxLean/Syntax/Parser.lean +++ b/CvxLean/Syntax/Parser.lean @@ -2,7 +2,7 @@ import Lean /-! This file defines how to parse variables, constraints, objective functions and full optimization -problems. Syntax matching these definitions is elaborated in `Syntax/Minimization.lean`. +problems. Syntax matching these definitions is elaborated in `CvxLean/Syntax/Minimization.lean`. -/ namespace CvxLean diff --git a/CvxLean/Tactic/Arith/Arith.lean b/CvxLean/Tactic/Arith/Arith.lean index 21611da6..a7b5a8de 100644 --- a/CvxLean/Tactic/Arith/Arith.lean +++ b/CvxLean/Tactic/Arith/Arith.lean @@ -5,6 +5,13 @@ import Mathlib.Tactic.Positivity import CvxLean.Tactic.Arith.NormNumVariants import CvxLean.Tactic.Arith.PositivityExt +/-! +# Arithmetic tactics + +This file defines the `positivity!` and `arith` tactics, which are slight extensions of the existing +arithmetic tactics in mathlib (`positivity` and `linarith`). +-/ + namespace Tactic open Lean Meta Elab Tactic Qq @@ -14,6 +21,11 @@ elab (name := cases_and) "cases_and" : tactic => do let mvarId' ← mvarId.casesAnd replaceMainGoal [mvarId'] +/-- Positivity can only work with goals and hypotheses of the form `0 R x`, where `R ∈ {<, ≤, ≠}`. +Our positivity extension is designed to admit hypotheses of the form `a R x` and `x R a` where +`a ≠ 0`. For example, `1 ≤ x → 0 ≤ log x`. To be able to use `positivity`, we rewrite `1 ≤ x` into +`0 ≤ x - 1`. This procedure does exactly that, changing goals and hypotheses to make them compatible +with `positivity`. -/ def preparePositivity (mvarId : MVarId) : MetaM MVarId := do mvarId.withContext do -- Adjust hypotheses if needed. @@ -60,7 +72,7 @@ def preparePositivity (mvarId : MVarId) : MetaM MVarId := do mvarId := g break else - throwError "prepare_positivity failed" + throwError "`prepare_positivity` failed." | none => continue return mvarId @@ -72,8 +84,8 @@ elab (name := prepare_positivity) "prepare_positivity" : tactic => do open Mathlib.Meta.Positivity -/-- Call `positivity` but if the expression has no free variables, we try to -apply `norm_num` first. -/ +/-- Call `positivity` but if the expression has no free variables, we try to apply `norm_num` +first. -/ def positivityMaybeNormNum : TacticM Unit := withMainContext do let g ← getMainGoal @@ -84,19 +96,21 @@ def positivityMaybeNormNum : TacticM Unit := `(tactic| (try { norm_num } <;> positivity)) else `(tactic| positivity) - let [] ← evalTacticAt tac g | throwError "positivity_maybe_norm_num failed" + let [] ← evalTacticAt tac g | throwError "`positivity_maybe_norm_num` failed." elab (name := positivity) "positivity_maybe_norm_num" : tactic => positivityMaybeNormNum end Tactic +/-- Extension of `positivity` with some pre-processing and the option to call `norm_num`. -/ syntax "positivity!" : tactic macro_rules | `(tactic| positivity!) => `(tactic| intros; cases_and; prepare_positivity; positivity_maybe_norm_num) +/-- Combination of tactics. We try `positivity!`, then `linarith`, then `norm_num`, then `simp`. -/ syntax "arith" : tactic macro_rules diff --git a/CvxLean/Tactic/Arith/NormNumVariants.lean b/CvxLean/Tactic/Arith/NormNumVariants.lean index 8b3e902e..b9d7e05d 100644 --- a/CvxLean/Tactic/Arith/NormNumVariants.lean +++ b/CvxLean/Tactic/Arith/NormNumVariants.lean @@ -2,6 +2,10 @@ import Lean import Mathlib.Tactic.NormNum import CvxLean.Lib.Math.Data.Real +/-! +Custom `norm_num` variants. +-/ + namespace CvxLean open Lean Elab Meta Tactic diff --git a/CvxLean/Tactic/Arith/PositivityExt.lean b/CvxLean/Tactic/Arith/PositivityExt.lean index 2a13630e..668a18eb 100644 --- a/CvxLean/Tactic/Arith/PositivityExt.lean +++ b/CvxLean/Tactic/Arith/PositivityExt.lean @@ -4,6 +4,12 @@ import Mathlib.Data.Real.Basic import Mathlib.Tactic.Positivity import CvxLean.Lib.Math.Data.Real +/-! +Ad-hoc positivity extensions for particular examples. If `positivity!` fails, then we can try to add +a new lemma here. This approach is not scalable, but is useful as long as we don't have a general +nonlinear arithmetic tactic. +-/ + namespace Mathlib.Meta.Positivity open Lean.Meta Qq Real diff --git a/CvxLean/Tactic/Basic/ChangeOfVariables.lean b/CvxLean/Tactic/Basic/ChangeOfVariables.lean index c7379a9b..58c84c8d 100644 --- a/CvxLean/Tactic/Basic/ChangeOfVariables.lean +++ b/CvxLean/Tactic/Basic/ChangeOfVariables.lean @@ -10,19 +10,45 @@ import CvxLean.Tactic.Basic.ConvOpt /-! # Change of variables + +This file defines the `change_of_variables` tactic. The idea is that change of variable functions +are registered as instances of the `ChangeOfVariables` typeclass. These instances are also currently +defined in this file. The tactic takes care of renaming the variables as needed, finding the change +of variables, and proving the condition. + +We illustrate it with an example of how to use it inside the `equivalence` command: +``` +equivalence eqv/q : + optimization (x : ℝ) + minimize x + subject to + c : 0 < x := by + change_of_variables (u) (x ↦ exp u) +``` +The resulting problem `q` looks as follows: +``` +optimization (u : ℝ) + minimize exp u + subject to + c : 0 < exp u := by +``` +We provide `change_of_variables!` as a convenient variant that also tries to remove trivial +constraints. In this case, it would remove `0 < exp u` as it is always true. -/ namespace CvxLean open Minimization -/-- -/ +/-- Change of variables functions are registered using this typeclass. The only requirement +(`property`) is that they have a right inverse `inv` over some domain given by `condition`. -/ class ChangeOfVariables {D E} (c : E → D) where inv : D → E condition : D → Prop property : ∀ x, condition x → c (inv x) = x -/-- -/ +/-- A change of variables is an equivalence-preserving transformation, as long as the condition +holds in the feasible set. -/ @[equiv] def ChangeOfVariables.toEquivalence {D E R} [Preorder R] {f : D → R} {cs : D → Prop} (c : E → D) [cov : ChangeOfVariables c] (h : ∀ x, cs x → cov.condition x) : @@ -37,27 +63,28 @@ def ChangeOfVariables.toEquivalence {D E R} [Preorder R] {f : D → R} {cs : D section Structural -/-- -/ +/-- Composing changes of variables results in a valid change of variables. -/ instance ChangeOfVariables.comp {D E F} (c₁ : E → D) (c₂ : F → E) [cov₁ : ChangeOfVariables c₁] [cov₂ : ChangeOfVariables c₂] : ChangeOfVariables (c₁ ∘ c₂) := { inv := cov₂.inv ∘ cov₁.inv condition := fun x => cov₁.condition x ∧ cov₂.condition (cov₁.inv x) property := fun x ⟨hx₁, hx₂⟩ => by simp [cov₂.property (cov₁.inv x) hx₂, cov₁.property x hx₁] } -/-- -/ +/-- Apply change of variables to the left of a product domain. -/ instance ChangeOfVariables.prod_left {D E F} (c : E → D) [cov : ChangeOfVariables c] : ChangeOfVariables (fun x : E × F => (c x.1, x.2)) := { inv := fun ⟨x₁, x₂⟩ => (cov.inv x₁, x₂) condition := fun ⟨x₁, _⟩ => cov.condition x₁ property := fun ⟨x₁, x₂⟩ hx => by simp [cov.property x₁ hx] } -/-- -/ +/-- Apply change of variables to the right of a product domain. -/ instance ChangeOfVariables.prod_right {D E F} (c : E → D) [cov : ChangeOfVariables c] : ChangeOfVariables (fun x : F × E => (x.1, c x.2)) := { inv := fun ⟨x₁, x₂⟩ => (x₁, cov.inv x₂) condition := fun ⟨_, x₂⟩ => cov.condition x₂ property := fun ⟨x₁, x₂⟩ hx => by simp [cov.property x₂ hx] } +/-- The identity map is a valid change of variables. -/ instance ChangeOfVariables.id {D} : ChangeOfVariables (fun x : D => x) := { inv := fun x => x condition := fun _ => True @@ -79,10 +106,10 @@ instance : ChangeOfVariables sqrt := condition := fun x => 0 ≤ x property := fun _ hx => sqrt_sq hx } --- NOTE(RFM): x ≠ 0 is technically not necessary as division is defined on all of ℝ, but we want to --- avoid division by zero. instance : ChangeOfVariables (fun x : ℝ => x⁻¹) := { inv := fun x => x⁻¹ + -- NOTE: x ≠ 0 is technically not necessary as `0⁻¹⁻¹ = 0`, but we want to explicitly always + -- avoid division by zero. condition := fun x => x ≠ 0 property := fun x _ => by field_simp } @@ -101,6 +128,11 @@ instance {a : ℝ} : ChangeOfVariables (fun x : ℝ => x + a) := condition := fun _ => True property := fun _ _ => by ring } +instance {a : ℕ} : ChangeOfVariables (fun x : ℝ => x + a) := + { inv := fun x => x - a + condition := fun _ => True + property := fun _ _ => by ring } + end RealInstances noncomputable section VecInstances @@ -110,30 +142,24 @@ instance {n : ℕ} {a : Fin n → ℝ} : ChangeOfVariables (fun (v : Fin n → condition := fun v => ∀ i, v i ≠ 0 ∧ a i ≠ 0 property := fun v hnot0 => by ext i - simp [←div_mul, div_self (hnot0 i).2, one_mul] } + simp [← div_mul, div_self (hnot0 i).2, one_mul] } end VecInstances -/- -The idea here is to have a tactic +open Lean Elab Meta Tactic Term - change_of_variables (u) (x ↦ e^u) +namespace Meta +/-- This defines the tactic `change_of_variables (u) (x ↦ exp u)`. It follows the following steps: 1. Detect exactly where to apply the change of variables. 2. Syntesize the instance. 2. Prove the conditions. -3. Apply the c-of-v to equivalence theorem. - -For now, it only works with real variables. --/ - -open Lean Elab Meta Tactic Term - -namespace Meta +3. Apply the change of variables to equivalence theorem. +For now, it only works with real variables.-/ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) - (changeStx : TSyntax `term) : EquivalenceBuilder := - fun eqvExpr g => do + (changeStx : TSyntax `term) : EquivalenceBuilder Unit := + fun eqvExpr g => g.withContext do let newVar := newVarStx.getId let varToChange := varToChangeStx.getId @@ -153,7 +179,7 @@ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) -- Construct change of variables function. let fvars := Array.mk <| vars.map (fun ⟨n, _⟩ => mkFVar (FVarId.mk n)) -- u ↦ c(u) - let changeFnStx ← `(fun $newVarStx => $changeStx) + let changeFnStx ← `(fun ($newVarStx : ℝ) => $changeStx) let changeFn ← Tactic.elabTerm changeFnStx none -- c(x) let changeTerm ← Core.betaReduce <| @@ -197,10 +223,11 @@ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) -- Solve change of variables condition. let gCondition := gsAfterApply[0]! - let (_, gCondition) ← gCondition.intros + let (_, gCondition) ← gCondition.introN 2 [`x, `h] let gsFinal ← evalTacticAt - (← `(tactic| simp [ChangeOfVariables.condition] <;> arith)) gCondition + (← `(tactic| (simp [ChangeOfVariables.condition] at * <;> arith))) gCondition if gsFinal.length != 0 then + trace[CvxLean.debug] "Could not prove {gsFinal}." throwError "Failed to solve change of variables condition." end Meta @@ -218,6 +245,7 @@ def evalChangeOfVariables : Tactic := fun stx => match stx with saveTacticInfoForToken stx | _ => throwUnsupportedSyntax +/-- Tries to remove trivial constraints after applying the change of variables. -/ syntax (name := changeOfVariablesAndRemove) "change_of_variables!" "(" ident ")" "(" ident "↦" term ")" : tactic diff --git a/CvxLean/Tactic/Basic/CleanUpComp.lean b/CvxLean/Tactic/Basic/CleanUpComp.lean index 3ce1202a..2187e6c9 100644 --- a/CvxLean/Tactic/Basic/CleanUpComp.lean +++ b/CvxLean/Tactic/Basic/CleanUpComp.lean @@ -1,6 +1,14 @@ import CvxLean.Meta.Minimization import CvxLean.Meta.Equivalence import CvxLean.Meta.TacticBuilder +import CvxLean.Meta.Util.Error + +/-! +# Clean up compositions + +Simple equivalence-preserving tactic that removes occurrences of `∘` (`Function.comp`) from the +objective function and constraints. These may appear after a map is applied to the problem. +-/ namespace CvxLean @@ -8,7 +16,7 @@ open Lean Meta Elab Tactic namespace Meta -/-- -/ +/-- Pre-process objective functions and constraints that have resulted from composing a function. -/ def cleanUpCompAux (e : Expr) (name : String) : MetaM Expr := do match e with | .app (.app (.app (.app (.app (.const ``Function.comp lvls) ty₁) ty₂) ty₃) f₁) f₂ => do @@ -17,8 +25,9 @@ def cleanUpCompAux (e : Expr) (name : String) : MetaM Expr := do return mkApp5 (mkConst ``Function.comp lvls) ty₁ ty₂ ty₃ f₁ f₂ | _ => throwError "{name} not of the form '... ∘ ...'" -/-- -/ -def cleanUpCompBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do +/-- Trivially builds an equivalence by simplifying `Function.comp` on the problem, essentially, +pushing all function compositions. -/ +def cleanUpCompBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let newObjFun ← cleanUpCompAux lhsMinExpr.objFun "objFun" let newConstraints ← cleanUpCompAux lhsMinExpr.constraints "constr" @@ -36,11 +45,12 @@ end Meta namespace Tactic -open Lean.Elab Lean.Elab.Tactic +open Elab Tactic +/-- Tactic that removes `∘` from the optimization problem by applying function compositions (see +`Function.comp_apply`). -/ syntax (name := cleanUpComp) "clean_up_comp" : tactic -/-- -/ @[tactic cleanUpComp] def evalCleanUpComp : Tactic := fun stx => match stx with | `(tactic| clean_up_comp) => do diff --git a/CvxLean/Tactic/Basic/ConvOpt.lean b/CvxLean/Tactic/Basic/ConvOpt.lean index c7e1fbd1..e5f18f4e 100644 --- a/CvxLean/Tactic/Basic/ConvOpt.lean +++ b/CvxLean/Tactic/Basic/ConvOpt.lean @@ -3,6 +3,13 @@ import CvxLean.Meta.Equivalence import CvxLean.Meta.TacticBuilder import CvxLean.Tactic.Basic.ShowVars +/-! +# Conversion mode + +This file defines the `conv_opt`, `conv_obj`, and `conv_constr` tactics. These tactics are used to +travel to the the component of the optimization problem that we want to modify. +-/ + namespace CvxLean open Lean @@ -11,18 +18,24 @@ namespace Tactic.Conv open Meta Elab Parser Tactic Conv +/-- Convert whole optimization problem. An optimization problem `p` may appear in the goal in the +following forms: `⊢ Solution p`, `⊢ p ≡ q`, `⊢ p ≼ q`, or `⊢ p ≽' q`. Applying `conv_opt` travels +down to `p` and turns the goal into `| p`. -/ syntax (name := convOpt) "conv_opt" "=>" (convSeq)? : tactic +/-- Following from the comment on `conv_opt`, if `p := ⟨f, cs⟩`, turn the goal into `| f(x)`. -/ syntax (name := convObj) "conv_obj" "=>" (convSeq)? : tactic +/-- Following from the comment on `conv_opt`, if `p := ⟨f, cs⟩`, and `cs(x) := cs₁(x) ∧ ⋯ ∧ csₙ(x)`, +find constraint `csᵢ` with the given tag and turn the goal into `| csᵢ(x)`. -/ syntax (name := convConstr) "conv_constr" (ident)? "=>" (convSeq)? : tactic /-- Wrapper function to enter conv mode on an optimization problem. * If `fullProb` is set, then we enter conv mode on the whole problem. * If `changeObjFun` is set, then we enter conv mode on the objective function. -* Otherwise, we enter conv mode on the constraints. -/ +* Otherwise, we enter `conv` mode on the constraints. -/ def convertOpt (fullProb changeObjFun : Bool := false) (convTac : TacticM Unit) : - EquivalenceBuilder := + EquivalenceBuilder Unit := fun _ g => g.withContext do -- Turn into equality goal. if let [gEq] ← g.apply (mkConst ``Minimization.Equivalence.ofEq) then @@ -38,64 +51,64 @@ def convertOpt (fullProb changeObjFun : Bool := false) (convTac : TacticM Unit) let gToChange := if changeObjFun then objFunGoal else constrsGoal if let _ :: _ ← evalTacticAt (← `(tactic| rfl)) gToIgnore then let failureLocation := if changeObjFun then "constraints" else "objective function" - throwError "`conv_opt` error: failed to close {failureLocation} subgoals." + throwConvOptError "failed to close {failureLocation} subgoals." else -- Introduce optimization variables. if let [gToConv] ← evalTacticAt (← `(conv| ext $(mkIdent `p))) gToChange then if let [gToConv] ← evalTacticAt (← `(conv| show_vars $(mkIdent `p))) gToConv then pure gToConv else - throwError "`conv_opt` error: failed to show optimization variables." + throwConvOptError "failed to show optimization variables." else - throwError "`conv_opt` error: failed to introduce optimization variables." + throwConvOptError "failed to introduce optimization variables." else - throwError "`conv_opt` error: unexpected goal type." + throwConvOptError "unexpected goal type." - -- Apply conv tactic. + -- Apply `conv` tactic. for mvarId in ← Tactic.run gToConv convTac do liftM <| mvarId.refl <|> mvarId.inferInstance <|> pure () else - throwError "`conv_opt` error: could not convert to equality goal." + throwConvOptError "could not convert to equality goal." -/-- Enter conv mode on the full problem. The `shouldEval` flag is set to false when no tactics are +/-- Enter `conv` mode on the full problem. The `shouldEval` flag is set to false when no tactics are applied but we still want to enter conv mode and see the goal. -/ -def convertFullProb (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder := +def convertFullProb (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder Unit := convertOpt (fullProb := true) (changeObjFun := false) do if shouldEval then evalTactic stx else saveTacticInfoForToken stx @[tactic convOpt] partial def evalConvOpt : Tactic := fun stx => match stx with | `(tactic| conv_opt => $code) => (convertFullProb true code).toTactic - -- Avoid errors on empty conv block. + -- Avoid errors on empty `conv` block. | `(tactic| conv_opt =>) => do (convertFullProb false stx).toTactic | _ => throwUnsupportedSyntax section ConvObj -/-- Enter conv mode on the objective function. -/ -def convertObj (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder := +/-- Enter `conv` mode on the objective function. -/ +def convertObj (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder Unit := convertOpt (fullProb := false) (changeObjFun := true) do if shouldEval then evalTactic stx else saveTacticInfoForToken stx @[tactic convObj] partial def evalConvObj : Tactic := fun stx => match stx with | `(tactic| conv_obj => $code) => (convertObj true code).toTactic - -- Avoid errors on empty conv block. - | `(tactic| conv_obj =>) => do (convertObj false stx).toTactic + -- Avoid errors on empty `conv` block. + | `(tactic| conv_obj =>) => (convertObj false stx).toTactic | _ => throwUnsupportedSyntax end ConvObj section ConvConstr -/-- Split ands in conv goal. -/ +/-- Split ands in `conv` goal. -/ partial def splitAnds (goal : MVarId) : TacticM (List MVarId) := do let lhs := (← getLhsRhsCore goal).1 match lhs with | .app (.app (.const ``And _) p) _q => do let (mp, mq) ← match List.filterMap id <| ← Conv.congr <| goal with | [mp, mq] => pure (mp, mq) - | _ => throwError "`conv_constr` error: unexpected number of subgoals in `splitAnds`." + | _ => throwConvConstrError "unexpected number of subgoals in `splitAnds`." let tag ← getLabelName p mp.setTag tag return mp :: (← splitAnds mq) @@ -104,14 +117,14 @@ partial def splitAnds (goal : MVarId) : TacticM (List MVarId) := do goal.setTag tag return [goal] -/-- Enter conv mode setting all constraints as subgoals. -/ -def convertConstrs (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder := +/-- Enter `conv` mode setting all constraints as subgoals. -/ +def convertConstrs (shouldEval : Bool) (stx : Syntax) : EquivalenceBuilder Unit := convertOpt (fullProb := false) (changeObjFun := false) do replaceMainGoal <| ← splitAnds <| ← getMainGoal if shouldEval then evalTactic stx else saveTacticInfoForToken stx -/-- Enter conv mode on a specific constraint. -/ -def convertConstrWithName (shouldEval : Bool) (stx : Syntax) (h : Name) : EquivalenceBuilder := +/-- Enter `conv` mode on a specific constraint. -/ +def convertConstrWithName (shouldEval : Bool) (stx : Syntax) (h : Name) : EquivalenceBuilder Unit := convertOpt (fullProb := false) (changeObjFun := false) do let constrs ← splitAnds <| ← getMainGoal let mut found := false @@ -131,17 +144,17 @@ def convertConstrWithName (shouldEval : Bool) (stx : Syntax) (h : Name) : Equiva else let gs ← evalTacticAt (← `(tactic| rfl)).raw constr if !gs.isEmpty then - throwError "`conv_constr` error: failed to close all subgoals." + throwConvConstrError "failed to close all subgoals." if !found then - throwError "`conv_constr` error: no constraint with name {h} found." + throwConvConstrError "no constraint with name {h} found." @[tactic convConstr] partial def evalConvConstr : Tactic := fun stx => match stx with | `(tactic| conv_constr => $code) => (convertConstrs true code).toTactic | `(tactic| conv_constr $h => $code) => (convertConstrWithName true code h.getId).toTactic - -- Avoid errors on empty conv block. - | `(tactic| conv_constr =>) => do (convertConstrs false stx).toTactic - | `(tactic| conv_constr $h =>) => do (convertConstrWithName false stx h.getId).toTactic + -- Avoid errors on empty `conv` block. + | `(tactic| conv_constr =>) => (convertConstrs false stx).toTactic + | `(tactic| conv_constr $h =>) => (convertConstrWithName false stx h.getId).toTactic | _ => throwUnsupportedSyntax end ConvConstr diff --git a/CvxLean/Tactic/Basic/NormNumOpt.lean b/CvxLean/Tactic/Basic/NormNumOpt.lean index a3933f83..c42ee7d1 100644 --- a/CvxLean/Tactic/Basic/NormNumOpt.lean +++ b/CvxLean/Tactic/Basic/NormNumOpt.lean @@ -1,9 +1,12 @@ import CvxLean.Tactic.Basic.ConvOpt /-! -# Simple tactic to +# Simple tactic to normalize numerical constants -This is used, for example, in `pre_dcp` so that numbers are in the right format. +Applying `norm_num` directly is possible, but makes us lose control over how the chain of relations +is built. By calling it in `conv` mode, we make sure that the normalization happens behind the +`Equivalence.ofEq` constructor, which ensures that we can always obtain a computable backward map, +as creating the map relies on simplifying the chain of relations. -/ namespace CvxLean @@ -12,7 +15,7 @@ open Lean Elab Meta Tactic namespace Meta -def normNumOptBuilder : EquivalenceBuilder := +def normNumOptBuilder : EquivalenceBuilder Unit := Conv.convertOpt (fullProb := true) (convTac := Mathlib.Tactic.elabNormNum1Conv mkNullNode) end Meta diff --git a/CvxLean/Tactic/Basic/RemoveConstr.lean b/CvxLean/Tactic/Basic/RemoveConstr.lean index e7b703ec..ba5ed613 100644 --- a/CvxLean/Tactic/Basic/RemoveConstr.lean +++ b/CvxLean/Tactic/Basic/RemoveConstr.lean @@ -37,7 +37,7 @@ def composeAndIntro (l : List Expr) : MetaM Expr := return mkApp4 (mkConst ``And.intro) eTy esTy e es /-- -/ -def removeConstrBuilder (id : Name) (proof : Syntax) : EquivalenceBuilder := fun eqvExpr g => +def removeConstrBuilder (id : Name) (proof : Syntax) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let (idxToRemove, numConstrs, newConstrs, toShow) ← @@ -106,12 +106,13 @@ end Meta namespace Tactic -syntax (name := removeConstr) "remove_constr" ident term : tactic +/-- -/ +syntax (name := removeConstr) "remove_constr" ident " => " tacticSeq : tactic @[tactic removeConstr] partial def evalRemoveConstr : Tactic := fun stx => match stx with -| `(tactic| remove_constr $id $proof) => do - (removeConstrBuilder id.getId proof).toTactic +| `(tactic| remove_constr $id => $tacStx) => do + (removeConstrBuilder id.getId (← `(term| by $tacStx))).toTactic saveTacticInfoForToken stx | _ => throwUnsupportedSyntax diff --git a/CvxLean/Tactic/Basic/RemoveTrivialConstrs.lean b/CvxLean/Tactic/Basic/RemoveTrivialConstrs.lean index 43838ec3..ef555f1e 100644 --- a/CvxLean/Tactic/Basic/RemoveTrivialConstrs.lean +++ b/CvxLean/Tactic/Basic/RemoveTrivialConstrs.lean @@ -10,7 +10,7 @@ open Lean Meta Elab Term Tactic namespace Meta -def removeTrivialConstrsBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do +def removeTrivialConstrsBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhs ← eqvExpr.toMinimizationExprLHS let constrNames ← withLambdaBody lhs.constraints fun _ constrsBody => do let cs ← decomposeConstraints constrsBody diff --git a/CvxLean/Tactic/Basic/RenameConstrs.lean b/CvxLean/Tactic/Basic/RenameConstrs.lean index 3de5a34c..7fe9364c 100644 --- a/CvxLean/Tactic/Basic/RenameConstrs.lean +++ b/CvxLean/Tactic/Basic/RenameConstrs.lean @@ -14,7 +14,7 @@ def replaceConstrName (name : Name) (e : Expr) : MetaM Expr := do /-- Rename the constraints using `names`. The `names` Array can be shorter then the number of constraints; then only the first few constraints are renamed. -/ -def renameConstrsBuilder (names : Array Lean.Name) : EquivalenceBuilder := fun eqvExpr g => do +def renameConstrsBuilder (names : Array Name) : EquivalenceBuilder Unit := fun eqvExpr g => do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS -- Create new renamed constraints. @@ -25,7 +25,7 @@ def renameConstrsBuilder (names : Array Lean.Name) : EquivalenceBuilder := fun e for i in [:names.size] do let newConstr ← replaceConstrName names[i]! constraints[i]! constraints := constraints.set! i newConstr - return ← mkLambdaFVars #[p] $ composeAnd constraints.toList + return ← mkLambdaFVars #[p] <| composeAnd constraints.toList let rhsMinExpr := { lhsMinExpr with constraints := newConstrs } if !(← isDefEq eqvExpr.rhs rhsMinExpr.toExpr) then diff --git a/CvxLean/Tactic/Basic/RenameVars.lean b/CvxLean/Tactic/Basic/RenameVars.lean index ea89a641..39053f01 100644 --- a/CvxLean/Tactic/Basic/RenameVars.lean +++ b/CvxLean/Tactic/Basic/RenameVars.lean @@ -9,7 +9,7 @@ open Lean Expr Meta Elab Tactic namespace Meta /-- -/ -def renameVarsBuilder (names : Array Lean.Name) : EquivalenceBuilder := fun eqvExpr g => do +def renameVarsBuilder (names : Array Name) : EquivalenceBuilder Unit := fun eqvExpr g => do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let vars ← decomposeDomain (← instantiateMVars eqvExpr.domainLHS) diff --git a/CvxLean/Tactic/Basic/ReorderConstrs.lean b/CvxLean/Tactic/Basic/ReorderConstrs.lean index 1f1eb553..6b5a9aee 100644 --- a/CvxLean/Tactic/Basic/ReorderConstrs.lean +++ b/CvxLean/Tactic/Basic/ReorderConstrs.lean @@ -27,7 +27,7 @@ def reorderConstrsExpr (ids : Array Name) (e : Expr) : MetaM Expr := do return composeAnd newConstrs.data /-- -/ -def reorderConstrsBuilder (ids : Array Name) : EquivalenceBuilder := fun eqvExpr g => +def reorderConstrsBuilder (ids : Array Name) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let newConstrs ← withLambdaBody lhsMinExpr.constraints fun p constrBody => do diff --git a/CvxLean/Tactic/Basic/RewriteOpt.lean b/CvxLean/Tactic/Basic/RewriteOpt.lean index 11cf9154..9fa37d1e 100644 --- a/CvxLean/Tactic/Basic/RewriteOpt.lean +++ b/CvxLean/Tactic/Basic/RewriteOpt.lean @@ -4,13 +4,14 @@ import CvxLean.Meta.TacticBuilder import CvxLean.Meta.Util.Error import CvxLean.Meta.Util.Debug import CvxLean.Tactic.Basic.ShowVars +import CvxLean.Tactic.Basic.RenameConstrs /-! # Tactics for conditional rewrites TODO -## TODO +### TODO * Clean up. -/ @@ -36,8 +37,8 @@ def rewriteObjLemma (numConstrs : Nat) : MetaM Name := | _ => throwRwObjError "can only rewrite problems with up to 10 constraints." /-- -/ -def rewriteObjBuilder (shouldEval : Bool) (tacStx : Syntax) (rhs? : Option Expr) : - EquivalenceBuilder := +def rewriteObjBuilderFromTactic (shouldEval : Bool) (tac : MVarId → FVarId → TacticM Unit) + (rhs? : Option Expr) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let constrTags ← withLambdaBody lhsMinExpr.constraints fun _ constrsBody => do @@ -59,12 +60,38 @@ def rewriteObjBuilder (shouldEval : Bool) (tacStx : Syntax) (rhs? : Option Expr) let (fvarIds, gAfterIntros) ← gToRw.introN (1 + numConstrs) ([`p] ++ constrTags) if fvarIds.size == 0 then throwRwObjError "could not introduce optimization variables." + let probFVarId := fvarIds[0]! let gAfterShowVars ← showVars gAfterIntros (fvarIds.get! 0) if shouldEval then - if let gs@(_ :: _) ← evalTacticAt tacStx gAfterShowVars then - trace[CvxLean.debug] "`rw_obj` could not close {gs} using {tacStx}." + if let gs@(_ :: _) ← Tactic.run gAfterShowVars (tac gAfterShowVars probFVarId) then + trace[CvxLean.debug] "`rw_obj` could not close {gs}." throwRwObjError "could not close all goals." +/-- -/ +def rewriteObjBuilder (shouldEval : Bool) (tacStx : Syntax) (rhs? : Option Expr) : + EquivalenceBuilder Unit := + fun eqvExpr g => do + if shouldEval then + trace[CvxLean.debug] "`rw_obj` using tactic {tacStx}." + let tac := fun g _ => evalTacticAt tacStx g >>= setGoals + rewriteObjBuilderFromTactic shouldEval tac rhs? eqvExpr g + +def rewriteObjBuilderWithoutTarget (shouldEval : Bool) (tacStx : Syntax) : + EquivalenceBuilder Unit := + rewriteObjBuilder shouldEval tacStx none + +def rewriteObjBuilderWithTarget (shouldEval : Bool) (tacStx : Syntax) (targetStx : Syntax) : + EquivalenceBuilder Unit := + fun eqvExpr g => g.withContext do + let lhsMinExpr ← eqvExpr.toMinimizationExprLHS + let vars ← decomposeDomain (← instantiateMVars lhsMinExpr.domain) + let fvars := Array.mk <| vars.map (fun ⟨n, _⟩ => mkFVar (FVarId.mk n)) + let target ← withLocalDeclD `p lhsMinExpr.domain fun p => do + Meta.withDomainLocalDecls lhsMinExpr.domain p fun xs prs => do + let target ← Tactic.elabTerm targetStx none + mkLambdaFVars #[p] ((target.replaceFVars fvars xs).replaceFVars xs prs) + rewriteObjBuilder shouldEval tacStx (some target) eqvExpr g + /-- Returns lemma to rewrite constraints at `rwIdx` and the name of the RHS parameter. -/ def rewriteConstrLemma (rwIdx : Nat) (numConstrs : Nat) : MetaM (Name × Name) := if rwIdx == numConstrs then @@ -124,8 +151,8 @@ end RIntro open Term /-- -/ -def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) - (rhs? : Option Expr) : EquivalenceBuilder := +def rewriteConstrBuilderFromTactic (shouldEval : Bool) (constrTag : Name) + (tac : MVarId → FVarId → TacticM Unit) (rhs? : Option Expr) : EquivalenceBuilder (List Name) := fun eqvExpr g => g.withContext do let lhsMinExpr ← eqvExpr.toMinimizationExprLHS let constrTags ← withLambdaBody lhsMinExpr.constraints fun _ constrsBody => do @@ -150,7 +177,7 @@ def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax if tag == rhsName && rhs?.isSome then g.assign rhs?.get! if !foundGToRw then - throwRwConstrError "error: could not find rewrite goal." + throwRwConstrError "could not find rewrite goal." -- Intros appropriately. let constrTagsBefore := constrTags.take rwIdx.pred let numConstrsBefore := constrTagsBefore.length @@ -172,32 +199,68 @@ def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax let gAfterShowVars ← showVars gAfterIntros probFVarId if shouldEval then - if let gs@(_ :: _) ← evalTacticAt tacStx gAfterShowVars then - trace[CvxLean.debug] "`rw_constr` could not close {gs} using {tacStx}." + if let gs@(_ :: _) ← Tactic.run gAfterShowVars (tac gAfterShowVars probFVarId) then + trace[CvxLean.debug] "`rw_constr` could not close {gs}." throwRwConstrError "could not close all goals." + return constrTags + +def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) + (rhs? : Option Expr) : EquivalenceBuilder (List Name) := + fun g eqvExpr => do + if shouldEval then + trace[CvxLean.debug] "`rw_constr` using tactic {tacStx}." + let tac := fun g _ => evalTacticAt tacStx g >>= setGoals + rewriteConstrBuilderFromTactic shouldEval constrTag tac rhs? g eqvExpr + +def rewriteConstrBuilderWithoutTarget (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) : + EquivalenceBuilder (List Name) := + rewriteConstrBuilder shouldEval constrTag tacStx none + +def rewriteConstrBuilderWithTarget (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) + (targetStx : Syntax) : EquivalenceBuilder (List Name) := + fun eqvExpr g => g.withContext do + let lhsMinExpr ← eqvExpr.toMinimizationExprLHS + let vars ← decomposeDomain (← instantiateMVars lhsMinExpr.domain) + let fvars := Array.mk <| vars.map (fun ⟨n, _⟩ => mkFVar (FVarId.mk n)) + let target ← withLocalDeclD `p lhsMinExpr.domain fun p => do + Meta.withDomainLocalDecls lhsMinExpr.domain p fun xs prs => do + let target ← Tactic.elabTerm targetStx none + mkLambdaFVars #[p] ((target.replaceFVars fvars xs).replaceFVars xs prs) + rewriteConstrBuilder shouldEval constrTag tacStx (some target) eqvExpr g + end Meta namespace Tactic -syntax (name := rwObj) "rw_obj" "=>" (tacticSeq)? : tactic +syntax (name := rwObj) "rw_obj" (&" into" term)? darrow (tacticSeq)? : tactic @[tactic rwObj] def evalRwObjFun : Tactic := fun stx => match stx with - | `(tactic|rw_obj => $tacStx) => do - (rewriteObjBuilder true tacStx none).toTactic | `(tactic|rw_obj =>) => do - (rewriteObjBuilder false stx none).toTactic + (rewriteObjBuilderWithoutTarget false stx).toTactic + | `(tactic|rw_obj => $tacStx) => do + (rewriteObjBuilderWithoutTarget true tacStx).toTactic + | `(tactic|rw_obj into $targetStx =>) => do + (rewriteObjBuilderWithTarget false stx targetStx).toTactic + | `(tactic|rw_obj into $targetStx => $tacStx) => do + (rewriteObjBuilderWithTarget true tacStx targetStx).toTactic | _ => throwUnsupportedSyntax -syntax (name := rwConstr) "rw_constr" (ident)? "=>" (tacticSeq)? : tactic +syntax (name := rwConstr) "rw_constr " (ident)? (&" into" term)? darrow (tacticSeq)? : tactic @[tactic rwConstr] def evalRwConstr : Tactic := fun stx => match stx with - | `(tactic|rw_constr $h => $tacStx) => do - (rewriteConstrBuilder true h.getId tacStx none).toTactic | `(tactic|rw_constr $h =>) => do - (rewriteConstrBuilder false h.getId stx none).toTactic + let _ ← (rewriteConstrBuilderWithoutTarget false h.getId stx).toTactic + | `(tactic|rw_constr $h => $tacStx) => do + let constrTags ← (rewriteConstrBuilderWithoutTarget true h.getId tacStx).toTactic + (renameConstrsBuilder constrTags.toArray).toTactic + | `(tactic|rw_constr $h into $targetStx =>) => do + let _ ← (rewriteConstrBuilderWithTarget false h.getId stx targetStx).toTactic + | `(tactic|rw_constr $h into $targetStx => $tacStx) => do + let constrTags ← (rewriteConstrBuilderWithTarget true h.getId tacStx targetStx).toTactic + (renameConstrsBuilder constrTags.toArray).toTactic | _ => throwUnsupportedSyntax end Tactic diff --git a/CvxLean/Tactic/DCP/AtomExt.lean b/CvxLean/Tactic/DCP/AtomExt.lean index fc1425ad..22a86623 100644 --- a/CvxLean/Tactic/DCP/AtomExt.lean +++ b/CvxLean/Tactic/DCP/AtomExt.lean @@ -130,115 +130,4 @@ def getAtomDiscrTree : MetaM (DiscrTree AtomData) := do end AtomData -section UnsafeAtomData - -/-- Version of `GraphAtomData` where the implementation of the atom is -processed dynamically. Instead of providing a concrete implementation, -meta-level code is given, which may use the internals of the atom library. -This is useful, for instance, for higher-order atoms. - -They are called unsafe because there is no assurance that they will be reduced -to conic form. The soundness of the system is not compromised as DCP is still -required to produce an equivalence. -/ -structure UnsafeGraphAtomData where - id : Name - curvature : Curvature - expr : Expr - argKinds : Array ArgKind - bconds : MetaM (Array (Lean.Name × Expr)) - vconds : MetaM (Array (Lean.Name × Expr)) - impVars : MetaM (Array (Lean.Name × Expr)) - impObjFun : MetaM Expr - impConstrs : MetaM (Array Expr) - solution : MetaM (Array Expr) - solEqAtom : MetaM Expr - feasibility : MetaM (Array Expr) - optimality : MetaM Expr - vcondElim : MetaM (Array Expr) -deriving Inhabited - -instance : BEq UnsafeGraphAtomData where - beq uga1 uga2 := - uga1.id == uga2.id && - uga1.curvature == uga2.curvature && - uga1.expr == uga2.expr && - uga1.argKinds == uga2.argKinds - -instance : ToMessageData UnsafeGraphAtomData where - toMessageData d := - m!"safe: false - id: {d.id} - curvature: {d.curvature} - expr: {d.expr} - argKinds: {d.argKinds}" - -/-- Execute the meta code to get `GraphAtomData` from `UnsafeGraphAtomData`. -/ -def UnsafeGraphAtomData.toGraphAtomData (uga : UnsafeGraphAtomData) : - MetaM GraphAtomData := do - return { - id := uga.id, - curvature := uga.curvature, - expr := uga.expr, - argKinds := uga.argKinds, - bconds := ← uga.bconds, - vconds := ← uga.vconds, - impVars := ← uga.impVars, - impObjFun := ← uga.impObjFun, - impConstrs := ← uga.impConstrs, - solution := ← uga.solution, - solEqAtom := ← uga.solEqAtom, - feasibility := ← uga.feasibility, - optimality := ← uga.optimality, - vcondElim := ← uga.vcondElim - } - -/-- Wrapper of `UnsafeGraphAtomData`. -/ -inductive UnsafeAtomData - | graph (d : UnsafeGraphAtomData) -deriving Inhabited, BEq - -instance : ToMessageData UnsafeAtomData where - toMessageData - | UnsafeAtomData.graph d => toMessageData d - -/-- Get the expression corresponding to the atom. -/ -def UnsafeAtomData.expr : UnsafeAtomData → Expr - | graph d => d.expr - -/-- Unwrap `AtomData`. -/ -def UnsafeAtomData.graph! : UnsafeAtomData → UnsafeGraphAtomData - | graph d => d - -/-- Environment extension to store atoms. -/ -def UnsafeAtomExtension : Type := - PersistentEnvExtension - (Array Key × UnsafeAtomData) (Array Key × UnsafeAtomData) (DiscrTree UnsafeAtomData) -deriving Inhabited - -initialize unsafeAtomExtension : UnsafeAtomExtension ← do - let atomExtension ← registerPersistentEnvExtension { - name := `unsafeAtomExtension - mkInitial := return {} - addImportedFn := fun as _ => - let addEntryFn : - DiscrTree UnsafeAtomData → Array Key × UnsafeAtomData → DiscrTree UnsafeAtomData := - fun s d => s.insertCore d.1 d.2 - return mkStateFromImportedEntries addEntryFn {} as, - addEntryFn := fun s d => s.insertCore d.1 d.2, - exportEntriesFn := fun s => s.toArray - } - return atomExtension - -/-- Add a new unsafe atom to the library. -/ -def addUnsafeAtom (data : UnsafeAtomData) : MetaM Unit := do - let (_, _, expr) ← lambdaMetaTelescope data.expr - let keys ← DiscrTree.mkPath expr - setEnv $ unsafeAtomExtension.addEntry (← getEnv) (keys, data) - -/-- Get the unsafe atoms tree. -/ -def getUnsafeAtomDiscrTree : MetaM (DiscrTree UnsafeAtomData) := do - return unsafeAtomExtension.getState (← getEnv) - -end UnsafeAtomData - end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean index 707d3869..3a519656 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean @@ -3,6 +3,18 @@ import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecCons import CvxLean.Lib.Math.Data.Vec +/-! +### TODO + +This is not defined in full generality. It can be made increasing or decreasing in each `xᵢ` +depending on the sign of `xᵢ`. Only affine arguments are accepted for now. + +As a first step, we should define two cases for when `x ≥ 0` and `x ≤ 0`. The issue is that in the +optimality condition we do not assume that `x'` satisfies any conditions. So, for example, in the +nonnegative case, we cannot prove that `∑ i, (x' i) ^ 2 ≤ ∑ i, (x i) ^ 2` just from `x' ≤ x`. We +would need to know that `0 ≤ x'`. +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean index 3a10b83c..5f6a4ea6 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean @@ -23,8 +23,23 @@ optimality by intros _ _ apply hx -open Matrix in +open BigOperators in +declare_atom Vec.sum' [affine] (m : Nat)& (x : Fin m → ℝ)+ : ∑ i, x i := +bconditions +homogenity by + simp only [Pi.zero_apply] + rw [Finset.smul_sum, Finset.sum_const_zero, add_zero, smul_zero, add_zero] + rfl +additivity by + simp only [Pi.zero_apply, Pi.add_apply] + rw [Finset.sum_const_zero, add_zero, Finset.sum_add_distrib] +optimality by + intro x' hx + apply Finset.sum_le_sum + intros _ _ + apply hx +open Matrix in declare_atom Matrix.sum [affine] (m : Nat)& (X : Matrix.{0,0,0} (Fin m) (Fin m) ℝ)+ : sum X := bconditions diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean index 3f23d67b..0129c9a6 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean @@ -45,4 +45,29 @@ optimality by exact le.optimality _ _ _ _ (hx i) (hy i) (h i) vconditionElimination +declare_atom Vec.le' [convex_set] (n : Nat)& (x : (Fin n) → ℝ)- (y : (Fin n) → ℝ)+ : + ∀ i, x i ≤ y i := +vconditions +implementationVars +implementationObjective Real.Vec.posOrthCone (y - x : (Fin n) → ℝ) +implementationConstraints +solution +solutionEqualsAtom by + unfold Real.Vec.posOrthCone + unfold Real.posOrthCone + rw [← iff_iff_eq] + constructor + · intros h i + rw [← le.solEqAtom] + unfold Real.posOrthCone + apply h + · intros hy i + simp [hy i] +feasibility +optimality by + intros x' y' hx hy h i + unfold Real.Vec.posOrthCone at h + exact le.optimality _ _ _ _ (hx i) (hy i) (h i) +vconditionElimination + end CvxLean diff --git a/CvxLean/Tactic/DCP/DCP.lean b/CvxLean/Tactic/DCP/DCP.lean index 2c56ff71..e072682e 100644 --- a/CvxLean/Tactic/DCP/DCP.lean +++ b/CvxLean/Tactic/DCP/DCP.lean @@ -75,6 +75,7 @@ partial def findAtoms (e : Expr) (vars : Array FVarId) (curvature : Curvature) : Bool × Array MessageData × AtomDataTrees) := do + let e := eta e -- Variable case. if e.isFVar ∧ vars.contains e.fvarId! then trace[Meta.debug] "Variable {e}" @@ -978,7 +979,7 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : return (redProblem, eqvProof) -def dcpBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do +def dcpBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let ogProblem ← eqvExpr.toMinimizationExprLHS let (_, eqvProof) ← canonize ogProblem if ! (← isDefEq (mkMVar g) eqvProof) then diff --git a/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean b/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean index 7ae8faab..b76431d8 100644 --- a/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean +++ b/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean @@ -82,7 +82,7 @@ partial def EggTree.adjustOps (t : Tree String String) : MetaM (Tree String Stri else return Tree.node newOp newChildren else - throwError s!"The atom {op} is not supported by the `convexify` tactic." + throwError s!"The atom {op} is not supported by the `pre_dcp` tactic." | Tree.leaf "unknown" => throwError "Unknown atom." | l => return l @@ -203,6 +203,7 @@ partial def EggTree.toEggString : Tree String String → String /-- Size of the AST. -/ partial def EggTree.size : EggTree → Nat + | Tree.node "var" _ => 1 | Tree.node _ children => 1 + (children.map EggTree.size).foldl Nat.add 0 | Tree.leaf _ => 1 diff --git a/CvxLean/Tactic/PreDCP/Egg/Runner.lean b/CvxLean/Tactic/PreDCP/Egg/Runner.lean index 497b5413..6bbc6cd9 100644 --- a/CvxLean/Tactic/PreDCP/Egg/Runner.lean +++ b/CvxLean/Tactic/PreDCP/Egg/Runner.lean @@ -126,16 +126,15 @@ def runEggRequestRaw (requestJson : String) : MetaM String := do def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do dbg_trace s!"Egg response: {responseString}" let outJson : Json ← match Json.parse responseString with - | Except.error e => throwError (s!"Error calling egg. JSON parsing error ({e}). " - ++ "It might be an issue with parsing inequalities.") + | Except.error e => throwError (s!"error calling `egg`, JSON parsing error ({e}).") | Except.ok j => pure j let responseType := (outJson.getObjValD "response").getStr! if responseType == "Error" then match (outJson.getObjValD "error").getStr? with - | Except.error _ => throwError "Error calling egg." - | Except.ok errorMessage => throwError "Error calling egg. Error: {errorMessage}" + | Except.error _ => throwError "error calling `egg`." + | Except.ok errorMessage => throwError "from `egg` (\"{errorMessage}\")." else let steps ← liftExcept <| outJson.getObjVal? "steps" diff --git a/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean b/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean index 0a0d22be..f58dbf9a 100644 --- a/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean +++ b/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean @@ -1,5 +1,6 @@ import CvxLean.Lib.Math.Data.Real import CvxLean.Meta.Minimization +import CvxLean.Meta.Util.Error import CvxLean.Tactic.PreDCP.Egg.Sexp import CvxLean.Tactic.DCP.Tree @@ -32,6 +33,11 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → -- Pi. | Tree.leaf "pi" => do return mkConst ``Real.pi + -- Special case: hole. Used to apply `congrAgr` in `pre_dcp`. + | Tree.leaf "?" => do + if "subexpr" ∈ vars then + throwPreDCPError "\"subexpr\" is not a valid variable name." + return mkFVar (FVarId.mk (Name.mkSimple "subexpr")) -- Numbers. | Tree.leaf s => match Json.Parser.num s.mkIterator with @@ -54,13 +60,13 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → let simpResult ← Mathlib.Meta.NormNum.deriveSimp (ctx := ← Simp.Context.mkDefault) (e := expr) return simpResult.expr - | _ => throwError "`pre_dcp` tree to Expr conversion error: unexpected num {s}." + | _ => throwPreDCPError "conversion error, unexpected num {s}." -- Variables. | Tree.node "var" #[Tree.leaf s] => if s ∈ vars then return mkFVar (FVarId.mk (Name.mkSimple s)) else - throwError "`pre_dcp` tree to Expr conversion error: unexpected var {s}." + throwPreDCPError "conversion error, unexpected var {s}." -- Parameters. | Tree.node "param" #[Tree.leaf s] => if s ∈ params then @@ -176,7 +182,7 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → return Meta.mkLabel (Name.mkSimple s) t -- Error. | Tree.node op children => - throwError "Tree to Expr conversion error: unexpected op {op} with {children.size} children." + throwPreDCPError "conversion error, unexpected op {op} with {children.size} children." where mkRealHBinAppExpr (opName instHName : Name) (nTyArgs : Nat) (instName : Name) (e1 e2 : Expr) : Expr := diff --git a/CvxLean/Tactic/PreDCP/PreDCP.lean b/CvxLean/Tactic/PreDCP/PreDCP.lean index 7e3f365f..f7a22eca 100644 --- a/CvxLean/Tactic/PreDCP/PreDCP.lean +++ b/CvxLean/Tactic/PreDCP/PreDCP.lean @@ -3,15 +3,15 @@ import CvxLean.Meta.Equivalence import CvxLean.Meta.TacticBuilder import CvxLean.Meta.Util.Error import CvxLean.Meta.Util.Debug -import CvxLean.Tactic.PreDCP.RewriteMapExt -import CvxLean.Tactic.PreDCP.RewriteMapLibrary +import CvxLean.Tactic.PreDCP.RuleToTacticExt +import CvxLean.Tactic.PreDCP.RuleToTacticLibrary import CvxLean.Tactic.PreDCP.Egg.All import CvxLean.Tactic.Basic.ConvOpt import CvxLean.Tactic.Basic.NormNumOpt import CvxLean.Tactic.Basic.RewriteOpt /-! -# Atomatic transformation to DCP form +# Automatic transformation to DCP form This file defines the tactic `pre_dcp`, which calls `egg` to find a sequence of rewrites that turns an optimization problem into DCP form. @@ -24,7 +24,7 @@ open Lean Elab Meta Tactic Term IO /-- Given the rewrite name and direction from egg's output, find the appropriate tactic in the environment. It also returns a bool to indicate if the proof needs an intermediate equality step. Otherwise, the tactic will be applied directly. -/ -def findTactic (atObjFun : Bool) (rewriteName : String) (direction : EggRewriteDirection) : +def findTactic (rewriteName : String) (direction : EggRewriteDirection) : MetaM (TSyntax `tactic × Bool) := do match ← getTacticFromRewriteName rewriteName with | some (tac, mapObjFun) => @@ -35,16 +35,13 @@ def findTactic (atObjFun : Bool) (rewriteName : String) (direction : EggRewriteD | EggRewriteDirection.Forward => return (tac, false) | EggRewriteDirection.Backward => -- Simply flip the goal so that the rewrite is applied to the target. - if atObjFun then - return (← `(tactic| (rw [eq_comm]; $tac)), false) - else - return (← `(tactic| (rw [Iff.comm]; $tac)), false) + return (← `(tactic| (symm; $tac)), false) | _ => throwPreDCPError "unknown rewrite name {rewriteName}({direction})." /-- Given an egg rewrite and a current goal with all the necessary information about the minimization problem, we find the appropriate rewrite to apply, and output the remaining goals. -/ def evalStep (step : EggRewrite) (vars params : List Name) (paramsDecls : List LocalDecl) - (tagsMap : HashMap String ℕ) : EquivalenceBuilder := fun eqvExpr g => g.withContext do + (tagsMap : HashMap String ℕ) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let tag ← liftMetaM <| do if step.location == "objFun" then return "objFun" @@ -55,43 +52,105 @@ def evalStep (step : EggRewrite) (vars params : List Name) (paramsDecls : List L let tagNum := tagsMap.find! step.location let atObjFun := tagNum == 0 - -- Build expexcted expression to generate the right rewrite condition. - let expectedTermStr := step.expectedTerm - let mut expectedExpr ← EggString.toExpr vars params expectedTermStr - if !atObjFun then - expectedExpr := Meta.mkLabel (Name.mkSimple tag) expectedExpr - let fvars := Array.mk <| vars.map (fun v => mkFVar (FVarId.mk v)) - let paramsFvars := Array.mk <| params.map (fun v => mkFVar (FVarId.mk v)) - let paramsDeclsIds := Array.mk <| paramsDecls.map (fun decl => mkFVar decl.fvarId) - let D ← instantiateMVars eqvExpr.domainLHS - expectedExpr ← - withLocalDeclD `p D fun p => do - Meta.withDomainLocalDecls D p fun xs prs => do - let expectedExpr := Expr.replaceFVars expectedExpr paramsFvars paramsDeclsIds - let expectedExpr := Expr.replaceFVars expectedExpr fvars xs - let expectedExpr := Expr.replaceFVars expectedExpr xs prs - mkLambdaFVars #[p] expectedExpr - - let (tacStx, isMap) ← findTactic atObjFun step.rewriteName step.direction + let (tacStx, isMap) ← findTactic step.rewriteName step.direction - -- Finally, apply the tactic that should solve all proof obligations. A mix of approaches using - -- `norm_num` in combination with the tactic provided by the user for this particular rewrite. - let tacStx : Syntax ← `(tactic| intros; + -- This tactic should solve the final goal. It uses the tactic provided by the user for this + -- particular rule in combination with `norm_num`. + let tacStx : Syntax ← `(tactic| try { $tacStx <;> norm_num_simp_pow } <;> try { norm_num_simp_pow }) if isMap then -- Maps, e.g., `map_objFun_log` are applied directly to the equivalence goal. - if let _ :: _ ← evalTacticAt tacStx g then + if let gs@(_ :: _) ← evalTacticAt tacStx g then + trace[CvxLean.debug] "`pre_dcp` failed to close goals: {gs}." throwPreDCPError "failed to apply {step.rewriteName}." else + -- Optimization and param free variables, and domain. Needed to build expressions and replace + -- variables appropriately. + let fvars := Array.mk <| vars.map (fun v => mkFVar (FVarId.mk v)) + let paramsFvars := Array.mk <| params.map (fun v => mkFVar (FVarId.mk v)) + let paramsDeclsIds := Array.mk <| paramsDecls.map (fun decl => mkFVar decl.fvarId) + let D ← instantiateMVars eqvExpr.domainLHS + + -- Convert string subexpressions given by `egg` into expressions. + let mut lhsSubExpr ← EggString.toExpr vars params step.subexprFrom + let mut rhsSubExpr ← EggString.toExpr vars params step.subexprTo + lhsSubExpr ← withLocalDeclD `p D fun p => do + Meta.withDomainLocalDecls D p fun xs prs => do + let lhsSubExpr := Expr.replaceFVars lhsSubExpr paramsFvars paramsDeclsIds + let lhsSubExpr := Expr.replaceFVars lhsSubExpr fvars xs + let lhsSubExpr := Expr.replaceFVars lhsSubExpr xs prs + mkLambdaFVars #[p] lhsSubExpr + rhsSubExpr ← withLocalDeclD `p D fun p => do + Meta.withDomainLocalDecls D p fun xs prs => do + let rhsSubExpr := Expr.replaceFVars rhsSubExpr paramsFvars paramsDeclsIds + let rhsSubExpr := Expr.replaceFVars rhsSubExpr fvars xs + let rhsSubExpr := Expr.replaceFVars rhsSubExpr xs prs + mkLambdaFVars #[p] rhsSubExpr + + -- Build expexcted expression to generate the right rewrite condition. Note that this has a hole + -- where the subexpression will be placed. For example, if + let mut expectedExpr ← EggString.toExpr vars params step.expectedTerm + if !atObjFun then + expectedExpr := Meta.mkLabel (Name.mkSimple tag) expectedExpr + expectedExpr ← + withLocalDeclD `p D fun p => do + Meta.withDomainLocalDecls D p fun xs prs => do + let expectedExpr := Expr.replaceFVars expectedExpr paramsFvars paramsDeclsIds + let expectedExpr := Expr.replaceFVars expectedExpr fvars xs + let expectedExpr := Expr.replaceFVars expectedExpr xs prs + withLocalDeclD `subexpr (← inferType (rhsSubExpr.mkAppNBeta #[p])) fun subexpr => + let expectedExpr := expectedExpr.replaceFVars #[mkFVar (FVarId.mk `subexpr)] #[subexpr] + mkLambdaFVars #[p, subexpr] expectedExpr + + -- The target expression is the expected expression above replacing `subexpr` with the + -- right-hand side subexpression. + let targetExpr ← withLocalDeclD `p D fun p => do + let targetExpr := expectedExpr.mkAppNBeta #[p, rhsSubExpr.mkAppNBeta #[p]] + mkLambdaFVars #[p] targetExpr + + -- This tactic first uses congruence to reduce the goal to the equality of the subexpressions + -- and then applies `tacStx`. + let tac : MVarId → FVarId → TacticM Unit := fun g p => do + -- Introduce other constraints. + let (_, g) ← g.intros + let mut g := g + -- Apply optimization variables to expresssions. + let lhsSubExprAtProb := lhsSubExpr.mkAppNBeta #[mkFVar p] + let rhsSubExprAtProb := rhsSubExpr.mkAppNBeta #[mkFVar p] + let expectedExprAtProb := expectedExpr.mkAppNBeta #[mkFVar p] + -- Distinguish between rewriting props and reals. No need to apply congruence if we are + -- rewriting the whole constraint. + if !(← inferType lhsSubExprAtProb).isProp then + -- Reduce to equality goal for real rewerites . + if !atObjFun then + let gs ← g.apply (mkConst ``Iff.of_eq) + if gs.length != 1 then + throwPreDCPError "failed to apply `iff_of_eq`." + g := gs[0]! + -- Apply congruence. + let subExprAtProbType ← inferType rhsSubExprAtProb + let congrLemma ← mkAppOptM ``congrArg + #[subExprAtProbType, none, lhsSubExprAtProb, rhsSubExprAtProb, expectedExprAtProb] + let gs ← g.apply congrLemma + if gs.length != 1 then + throwPreDCPError "failed to apply `congrArg`." + g := gs[0]! + -- Apply the tactic. + let gs ← evalTacticAt tacStx g + if gs.length != 0 then + trace[CvxLean.debug] "`pre_dcp` failed to close goals: {gs}." + throwPreDCPError "failed to apply {step.rewriteName}." + pure () + -- Rewrites use the machinery from `Tactic.Basic.RewriteOpt`. if atObjFun then - rewriteObjBuilder true tacStx (some expectedExpr) eqvExpr g + rewriteObjBuilderFromTactic true tac targetExpr eqvExpr g else - rewriteConstrBuilder true (Name.mkSimple tag) tacStx (some expectedExpr) eqvExpr g + let _ ← rewriteConstrBuilderFromTactic true (Name.mkSimple tag) tac targetExpr eqvExpr g -def preDCPBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do +def preDCPBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let lhs ← eqvExpr.toMinimizationExprLHS -- Get optimization variables. @@ -123,6 +182,7 @@ def preDCPBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do -- Get goal as tree and create tags map. let (gStr, domainConstrs) ← ExtendedEggTree.fromMinimization lhs varsStr + let probSize := (gStr.map fun (_, t) => t.size).fold 0 Nat.add let mut tagsMap := HashMap.empty tagsMap := tagsMap.insert "objFun" 0 let mut idx := 1 @@ -151,8 +211,7 @@ def preDCPBuilder : EquivalenceBuilder := fun eqvExpr g => g.withContext do let diff := aft - bef dbg_trace s!"Egg time: {diff} ms." dbg_trace s!"Number of steps: {steps.size}." - let size := (gStr.map fun (_, t) => t.size).fold 0 Nat.add - dbg_trace s!"Term size: {size}." + dbg_trace s!"Term size: {probSize}." dbg_trace s!"Term JSON: {eggMinimization.toJson}." -- Apply steps. diff --git a/CvxLean/Tactic/PreDCP/RewriteMapCmd.lean b/CvxLean/Tactic/PreDCP/RewriteMapCmd.lean deleted file mode 100644 index 07717350..00000000 --- a/CvxLean/Tactic/PreDCP/RewriteMapCmd.lean +++ /dev/null @@ -1,48 +0,0 @@ -import Lean -import Std.Linter.UnreachableTactic -import CvxLean.Tactic.PreDCP.RewriteMapExt - -namespace CvxLean - -open Lean Parser - -syntax (name := registerRewriteMap) - "register_rewrite_map " str " ; " str " => " str " := " tactic ";" : command - -syntax (name := registerRewriteMapBidirectional) - "register_rewrite_map " str " ; " str " <=> " str " := " tactic ";" : command - -macro_rules - | `(register_rewrite_map $rwName ; $rwTarget <=> $rwGoal := $tac;) => - if let some rwNameStr := Syntax.isStrLit? rwName then - let rwNameRev := Syntax.mkStrLit (rwNameStr ++ "-rev") - `(register_rewrite_map $rwName ; $rwTarget => $rwGoal := $tac; - register_rewrite_map $rwNameRev ; $rwGoal => $rwTarget := $tac;) - else `(throwError "register_rewrite_map error: expected string") - -syntax (name := registerObjFunRewriteMap) - "register_objFun_rewrite_map " str " ; " str " => " str " := " tactic ";" : command - -open Lean.Elab Lean.Elab.Command - -set_option linter.unreachableTactic false in -@[command_elab registerRewriteMap] def elabRegisterEggRewrite : CommandElab -| `(register_rewrite_map $rwName ; $rwTarget => $rwGoal := $tac;) => do - let rwNameStr := rwName.getString - -- NOTE: We ignore this for now. - let _rwTargetStr := rwTarget.getString - let _rwGoalStr := rwGoal.getString - liftTermElabM <| addRewriteMapEntry rwNameStr tac false -| _ => throwUnsupportedSyntax - -set_option linter.unreachableTactic false in -@[command_elab registerObjFunRewriteMap] def elabRegisterObjFunEggRewrite : CommandElab -| `(register_objFun_rewrite_map $rwName ; $rwTarget => $rwGoal := $tac;) => do - let rwNameStr := rwName.getString - -- NOTE: We ignore this for now. - let _rwTargetStr := rwTarget.getString - let _rwGoalStr := rwGoal.getString - liftTermElabM <| addRewriteMapEntry rwNameStr tac true -| _ => throwUnsupportedSyntax - -end CvxLean diff --git a/CvxLean/Tactic/PreDCP/RewriteMapExt.lean b/CvxLean/Tactic/PreDCP/RewriteMapExt.lean deleted file mode 100644 index ecf37ef1..00000000 --- a/CvxLean/Tactic/PreDCP/RewriteMapExt.lean +++ /dev/null @@ -1,40 +0,0 @@ -import Lean - -namespace CvxLean - -open Lean Lean.Meta - -/-- An entry is the name of the rewrite, the corresponding tactic, and whether -it is an objective function map or not. -/ -def RewriteMapExtEntry : Type := String × ((TSyntax `tactic) × Bool) -deriving Inhabited - -def RewriteMapExtState : Type := HashMap String ((TSyntax `tactic) × Bool) -deriving Inhabited - -/-- Environment extension to store the mapping between rewrite names in egg and -Lean tactics. -/ -def RewriteMapExt : Type := - SimplePersistentEnvExtension RewriteMapExtEntry RewriteMapExtState -deriving Inhabited - -initialize rewriteMapExt : RewriteMapExt ← - registerSimplePersistentEnvExtension { - addImportedFn := fun as => - mkStateFromImportedEntries (fun s (k, v) => s.insert k v) default as, - addEntryFn := fun s (k, v) => s.insert k v, - } - -/-- Add a new rewrite mapping to the environment. -/ -def addRewriteMapEntry (rwName : String) (tac : TSyntax `tactic) (mapObjFun : Bool) : MetaM Unit := do - setEnv <| rewriteMapExt.addEntry (← getEnv) (rwName, (tac, mapObjFun)) - -/-- Given rewrite name, return associated tactic in the environment. -/ -def getTacticFromRewriteName (rwName : String) : MetaM (Option (TSyntax `tactic × Bool)) := do - return (rewriteMapExt.getState (← getEnv)).find? rwName - -/-- Return all the saved rewrite names. -/ -def getRegisteredRewriteNames : MetaM (List String) := do - return (rewriteMapExt.getState (← getEnv)).toList.map (·.1) - -end CvxLean diff --git a/CvxLean/Tactic/PreDCP/RewriteMapLibrary.lean b/CvxLean/Tactic/PreDCP/RewriteMapLibrary.lean deleted file mode 100644 index 50117d42..00000000 --- a/CvxLean/Tactic/PreDCP/RewriteMapLibrary.lean +++ /dev/null @@ -1,362 +0,0 @@ -import CvxLean.Lib.Equivalence -import CvxLean.Tactic.PreDCP.RewriteMapCmd -import CvxLean.Tactic.Arith.Arith - -/-- Attempt to close the goal using the lemma specified with a combination of -`simp` and `rw`. -/ -syntax "simp_or_rw" "[" Lean.Parser.Tactic.rwRule "]" : tactic - -macro_rules -| `(tactic| simp_or_rw [$rule]) => - let symm := !rule.raw[0].isNone - match rule.raw[1] with - | `(term| $e:term) => - if symm then - `(tactic| (try { simp only [←$e:term] } <;> repeat' (rw [←$e:term]))) - else - `(tactic| (try { simp only [$e:term] } <;> repeat' (rw [$e:term]))) - -namespace CvxLean - -/- Objective function rules. -/ - -register_objFun_rewrite_map "map_objFun_log"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (log ?a)) ?cs)" := - apply Minimization.Equivalence.map_objFun_log (by positivity!); - -register_objFun_rewrite_map "map_objFun_sq"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (pow ?a 2)) ?cs)" := - apply Minimization.Equivalence.map_objFun_sq (by positivity!); - - -/- Equality rules. -/ - -register_rewrite_map "log_eq_log" ; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" := - rw [Real.log_eq_log (by positivity!) (by positivity!)]; - - -/- Less than or equal rules. -/ - -register_rewrite_map "le_sub_iff_add_le" ; "(le ?a (sub ?b ?c))" => "(le (add ?a ?c) ?b)" := - rw [le_sub_iff_add_le]; - -register_rewrite_map "le_sub_iff_add_le-rev" ; "(le (add ?a ?c) ?b)" => "(le ?a (sub ?b ?c))" := - rw [←le_sub_iff_add_le]; - -register_rewrite_map "sub_le_iff_le_add"; "(le (sub ?a ?c) ?b)" => "(le ?a (add ?b ?c))" := - rewrite [sub_le_iff_le_add]; - -register_rewrite_map "sub_le_iff_le_add-rev"; "(le ?a (add ?b ?c))" => "(le (sub ?a ?c) ?b)" := - rw [←sub_le_iff_le_add]; - -register_rewrite_map "div_le_iff" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := - rw [div_le_iff (by positivity!)]; - -register_rewrite_map "div_le_iff-rev" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := - rw [←div_le_iff (by positivity!)]; - -register_rewrite_map "div_le_one-rev" ; "(le ?a ?b)" => "(le (div ?a ?b) 1)" := - rw [←div_le_one (by positivity!)]; - -register_rewrite_map "log_le_log" ; "(le (log ?a) (log ?b))" => "(le ?a ?b)" := - rw [Real.log_le_log_iff (by positivity!) (by positivity!)]; - -register_rewrite_map "log_le_log-rev" ; "(le ?a ?b)" => "(le (log ?a) (log ?b))" := - rw [←Real.log_le_log_iff (by positivity!) (by positivity!)]; - -register_rewrite_map "pow_two_le_pow_two"; "(le (pow ?a 2) (pow ?b 2))" => "(le ?a ?b)":= - rw [Real.pow_two_le_pow_two (by positivity!) (by positivity!)]; - -register_rewrite_map "pow_two_le_pow_two-rev"; "(le ?a ?b)" => "(le (pow ?a 2) (pow ?b 2))" := - rw [←Real.pow_two_le_pow_two (by positivity!) (by positivity!)]; - - -/- Field rules. -/ - -register_rewrite_map "neg_neg" ; "(neg (neg ?a))" => "?a" := - simp_or_rw [neg_neg (G := ℝ)]; - -register_rewrite_map "add_zero" ; "(add ?a 0)" => "?a" := - simp_or_rw [add_zero]; - -register_rewrite_map "add_comm" ; "(add ?a ?b)" => "(add ?b ?a)" := - simp_or_rw [add_comm]; - -register_rewrite_map "add_assoc" ; "(add (add ?a ?b) ?c)" => "(add ?a (add ?b ?c))" := - simp_or_rw [add_assoc]; - -register_rewrite_map "sub_self" ; "(sub ?a ?a)" => "0" := - simp_or_rw [sub_self]; - -register_rewrite_map "one_mul" ; "(mul 1 ?a)" => "?a" := - simp_or_rw [one_mul]; - --- Exception, we cannot find the pattern otherwise. -register_rewrite_map "one_mul-rev" ; "?a" => "(mul 1 ?a)" := - norm_num; - -register_rewrite_map "mul_zero" ; "(mul ?a 0)" => "0" := - simp_or_rw [mul_zero (M₀ := ℝ)]; - -register_rewrite_map "mul_comm" ; "(mul ?a ?b)" => "(mul ?b ?a)" := - simp_or_rw [mul_comm]; - -register_rewrite_map "mul_assoc" ; "(mul (mul ?a ?b) ?c)" => "(mul ?a (mul ?b ?c))" := - simp_or_rw [mul_assoc]; - -register_rewrite_map "sub_eq_add_neg"; "(sub ?a ?b)" => "(add ?a (neg ?b))" := - simp_or_rw [sub_eq_add_neg (G := ℝ)]; - -register_rewrite_map "sub_eq_add_neg-rev"; "(add ?a (neg ?b))" => "(sub ?a ?b)" := - simp_or_rw [←sub_eq_add_neg (G := ℝ)]; - -register_rewrite_map "add_sub" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := - simp_or_rw [add_sub]; - -register_rewrite_map "add_sub-rev" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := - simp_or_rw [←add_sub]; - -register_rewrite_map "add_mul" ; "(mul (add ?a ?b) ?c)" => "(add (mul ?a ?c) (mul ?b ?c))" := - simp_or_rw [add_mul]; - -register_rewrite_map "add_mul-rev" ; "(add (mul ?a ?c) (mul ?b ?c))" => "(mul (add ?a ?b) ?c)" := - simp_or_rw [←add_mul]; - -register_rewrite_map "mul_add" ; "(mul ?a (add ?b ?c))" => "(add (mul ?a ?b) (mul ?a ?c))" := - simp_or_rw [mul_add]; - -register_rewrite_map "mul_add-rev" ; "(add (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (add ?b ?c))" := - simp_or_rw [←mul_add]; - -register_rewrite_map "mul_sub" ; "(mul ?a (sub ?b ?c))" => "(sub (mul ?a ?b) (mul ?a ?c))" := - simp_or_rw [←mul_sub]; - -register_rewrite_map "mul_sub-rev" ; "(sub (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (sub ?b ?c))" := - simp_or_rw [←mul_sub]; - -register_rewrite_map "div_one" ; "(div ?a 1)" => "?a" := - simp_or_rw [div_one]; - -register_rewrite_map "add_div" ; "(div (add ?a ?b) ?c)" => "(add (div ?a ?c) (div ?b ?c))" := - simp_or_rw [add_div]; - -register_rewrite_map "add_div-rev" ; "(add (div ?a ?c) (div ?b ?c))" => "(div (add ?a ?b) ?c)" := - simp_or_rw [←add_div]; - -register_rewrite_map "mul_div" ; "(mul ?a (div ?b ?c))" => "(div (mul ?a ?b) ?c)" := - simp_or_rw [mul_div (G := ℝ)]; - -register_rewrite_map "mul_div-rev" ; "(div (mul ?a ?b) ?c)" => "(mul ?a (div ?b ?c))" := - simp_or_rw [←mul_div (G := ℝ)]; - -register_rewrite_map "div_mul" ; "(mul (div ?a ?b) ?c)" => "(div ?a (div ?b ?c))" := - simp_or_rw [div_mul]; - -register_rewrite_map "div_mul-rev" ; "(div ?a (div ?b ?c))" => "(mul (div ?a ?b) ?c)" := - simp_or_rw [←div_mul]; - -register_rewrite_map "mul_div_mul_left" ; "(div (mul ?c ?a) (mul ?c ?b))" => "(div ?a ?b)" := - simp_or_rw [mul_div_mul_left (G₀ := ℝ) _ _ (by positivity!)]; - -register_rewrite_map "div_div" ; "(div (div ?a ?b) ?c)" => "(div ?a (mul ?b ?c))" := - simp_or_rw [div_div]; - -register_rewrite_map "div_div-rev" ; "(div ?a (mul ?b ?c))" => "(div (div ?a ?b) ?c)" := - simp_or_rw [←div_div]; - -register_rewrite_map "div_self" ; "(div ?a ?a)" => "1" := - simp_or_rw [div_self (G₀ := ℝ) (by positivity!)]; - -register_rewrite_map "inv_eq_one_div" ; "(inv ?a)" => "(div 1 ?a)" := - simp_or_rw [inv_eq_one_div (G := ℝ)]; - -register_rewrite_map "inv_inv" ; "(inv (inv ?a))" => "?a" := - simp_or_rw [inv_inv]; - - -/- Power and square root rules. -/ - -register_rewrite_map "one_pow"; "(pow 1 ?a)" => "1" := - simp_or_rw [Real.one_rpow]; - -register_rewrite_map "pow_one"; "(pow ?a 1)" => "?a" := - simp_or_rw [Real.rpow_one]; - -register_rewrite_map "pow_add"; "(pow ?a (add ?b ?c))" => "(mul (pow ?a ?b) (pow ?a ?c))" := - simp_or_rw [Real.rpow_add (by positivity!)]; - -register_rewrite_map "pow_add-rev"; "(mul (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (add ?b ?c))" := - simp_or_rw [←Real.rpow_add (by positivity!)]; - -register_rewrite_map "mul_pow"; "(mul (pow ?a ?n) (pow ?b ?n))" => "(pow (mul ?a ?b) ?n)" := - simp_or_rw [Real.mul_rpow (by positivity!) (by positivity!)]; - -register_rewrite_map "mul_pow-rev"; "(mul (pow ?a ?n) (pow ?b ?n))" => "(pow (mul ?a ?b) ?n)" := - simp_or_rw [←Real.mul_rpow (by positivity!) (by positivity!)]; - -register_rewrite_map "pow_mul"; "(pow ?a (mul ?n ?m))" => "(pow (pow ?a ?n) ?m)" := - simp_or_rw [Real.rpow_mul (by positivity!)]; - -register_rewrite_map "pow_mul-rev"; "(pow (pow ?a ?n) ?m)" => "(pow ?a (mul ?n ?m))" := - simp_or_rw [←Real.rpow_mul (by positivity!)]; - -register_rewrite_map "div_pow"; "(pow (div ?a ?b) ?n)" => "(div (pow ?a ?n) (pow ?b ?n))" := - simp_or_rw [Real.div_rpow (by positivity!) (by positivity!)]; - -register_rewrite_map "div_pow-rev"; "(div (pow ?a ?n) (pow ?b ?n))" => "(pow (div ?a ?b) ?n)" := - simp_or_rw [←Real.div_rpow (by positivity!) (by positivity!)]; - -register_rewrite_map "pow_sub"; "(pow ?a (sub ?b ?c))" => "(div (pow ?a ?b) (pow ?a ?c))" := - simp_or_rw [Real.rpow_sub (by positivity!)]; - -register_rewrite_map "pow_sub-rev"; "(div (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (sub ?b ?c))" := - simp_or_rw [←Real.rpow_sub (by positivity!)]; - -register_rewrite_map "div_pow_eq_mul_pow_neg" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := - simp_or_rw [Real.div_pow_eq_mul_pow_neg (by positivity!)]; - -register_rewrite_map "div_pow_eq_mul_pow_neg-rev" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := - simp_or_rw [←Real.div_pow_eq_mul_pow_neg (by positivity!)]; - -register_rewrite_map "one_div_eq_pow_neg_one"; "(div 1 ?a)" => "(pow ?a (neg 1))" := - simp_or_rw [Real.one_div_eq_pow_neg_one (by positivity!)]; - -register_rewrite_map "sqrt_eq_rpow" ; "(sqrt ?a)" => "(pow ?a 0.5)" := - simp_or_rw [Real.sqrt_eq_rpow]; - -register_rewrite_map "sqrt_eq_rpow-rev" ; "(pow ?a 0.5)" => "(sqrt ?a)" := - simp_or_rw [←Real.sqrt_eq_rpow]; - -register_rewrite_map "sqrt_mul" ; "(sqrt (mul ?a ?b))" => "(mul (sqrt ?a) (sqrt ?b))" := - simp_or_rw [Real.sqrt_mul (by positivity!)]; - -register_rewrite_map "sqrt_mul-rev" ; "(mul (sqrt ?a) (sqrt ?b))" => "(sqrt (mul ?a ?b))" := - simp_or_rw [←Real.sqrt_mul (by positivity!)]; - -register_rewrite_map "pow_two"; "(pow ?a 2)" => "(mul ?a ?a)" := - (norm_cast; simp_or_rw [pow_two]); - -register_rewrite_map "pow_two-rev"; "(mul ?a ?a)" => "(pow ?a 2)" := - (norm_cast; simp_or_rw [←pow_two]); - -register_rewrite_map "pow_half_two"; "(pow (pow ?a 0.5) 2)" => "?a" := - simp_or_rw [Real.pow_half_two (by positivity!)]; - --- Exception, technically ← but we cannot find the pattern otherwise. -register_rewrite_map "pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" := - simp_or_rw [Real.pow_half_two (by positivity!)]; - -register_rewrite_map "binomial_two"; "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))" := - simp_or_rw [Real.binomial_two]; - -register_rewrite_map "rpow_eq_mul_rpow_pred"; "(pow ?a ?b)" => "(mul ?a (pow ?a (sub ?b 1)))" := - simp_or_rw [Real.rpow_eq_mul_rpow_pred (by positivity!)]; - -register_rewrite_map "inv_eq_pow_neg_one"; "(inv ?a)" => "(pow ?a (neg 1))" := - simp_or_rw [Real.inv_eq_pow_neg_one (by positivity!)]; - - -/- Exponential and logarithm rules. -/ - -register_rewrite_map "exp_add" ; "(exp (add ?a ?b))" => "(mul (exp ?a) (exp ?b))" := - simp_or_rw [Real.exp_add]; - -register_rewrite_map "exp_add-rev" ; "(mul (exp ?a) (exp ?b))" => "(exp (add ?a ?b))" := - simp_or_rw [←Real.exp_add]; - -register_rewrite_map "exp_sub" ; "(exp (sub ?a ?b))" => "(div (exp ?a) (exp ?b))" := - simp_or_rw [Real.exp_sub]; - -register_rewrite_map "exp_sub-rev" ; "(div (exp ?a) (exp ?b))" => "(exp (sub ?a ?b))" := - simp_or_rw [←Real.exp_sub]; - -register_rewrite_map "exp_mul" ; "(exp (mul ?a ?b))" => "(pow (exp ?a) ?b)" := - simp_or_rw [Real.exp_mul]; - -register_rewrite_map "exp_mul-rev" ; "(pow (exp ?a) ?b)" => "(exp (mul ?a ?b))" := - simp_or_rw [←Real.exp_mul]; - -register_rewrite_map "exp_neg_eq_one_div"; "(exp (neg ?a))" => "(div 1 (exp ?a))" := - simp_or_rw [Real.exp_neg_eq_one_div]; - -register_rewrite_map "exp_neg_eq_one_div-rev" ; "(div 1 (exp ?a))" => "(exp (neg ?a))" := - simp_or_rw [←Real.exp_neg_eq_one_div]; - -register_rewrite_map "log_mul" ; "(log (mul ?a ?b))" => "(add (log ?a) (log ?b))" := - simp_or_rw [Real.log_mul (by positivity!) (by positivity!)]; - -register_rewrite_map "log_mul-rev"; "(add (log ?a) (log ?b))" => "(log (mul ?a ?b))" := - simp_or_rw [←Real.log_mul (by positivity!) (by positivity!)]; - -register_rewrite_map "log_div" ; "(log (div ?a ?b))" => "(sub (log ?a) (log ?b))" := - simp_or_rw [Real.log_div (by positivity!) (by positivity!)]; - -register_rewrite_map "log_div-rev"; "(sub (log ?a) (log ?b))" => "(log (div ?a ?b))" := - simp_or_rw [←Real.log_div (by positivity!) (by positivity!)]; - -register_rewrite_map "log_pow" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := - simp_or_rw [Real.log_rpow (by positivity!)]; - -register_rewrite_map "log_pow-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := - simp_or_rw [←Real.log_rpow (by positivity!)]; - --- Special type of rule as it only works if the exponent is a natural number. -register_rewrite_map "log_pow_nat" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := - (norm_cast; simp_or_rw [Real.log_pow]); - -register_rewrite_map "log_pow_nat-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := - (norm_cast; simp_or_rw [←Real.log_pow]); - -register_rewrite_map "exp_log" ; "(exp (log ?a))" => "?a" := - simp_or_rw [Real.exp_log (by positivity!)]; - -register_rewrite_map "log_exp" ; "(log (exp ?a))" => "?a" := - simp_or_rw [Real.log_exp]; - - -/- Absolute value rules. -/ - -register_rewrite_map "abs_nonneg" ; "(abs ?a)" => "?a" := - simp_or_rw [abs_eq_self.mpr (by positivity!)]; - -register_rewrite_map "abs_nonpos" ; "(abs ?a)" => "(neg ?a)" := - simp_or_rw [abs_eq_neg_self.mpr (by linarith)]; - - -/- Atom folding. -/ - -register_rewrite_map "xexp_fold"; "(mul ?a (exp ?a))" => "(xexp ?a)" := - rfl; - -register_rewrite_map "xexp_unfold"; "(xexp ?a)" => "(mul ?a (exp ?a))" := - rfl; - -register_rewrite_map "entr_fold"; "(neg (mul ?a (log ?a)))" => "(entr ?a)" := - rfl; - -register_rewrite_map "entr_unfold"; "(entr ?a)" => "(neg (mul ?a (log ?a)))" := - rfl; - -register_rewrite_map "qol_fold"; "(div (pow ?a 2) ?b)" => "(qol ?a ?b)" := - rfl; - -register_rewrite_map "qol_unfold"; "(qol ?a ?b)" => "(div (pow ?a 2) ?b)" := - rfl; - -register_rewrite_map "geo_fold"; "(sqrt (mul ?a ?b))" => "(geo ?a ?b)" := - rfl; - -register_rewrite_map "geo_unfold"; "(geo ?a ?b)" => "(sqrt (mul ?a ?b))" := - rfl; - -register_rewrite_map "lse_fold"; "(log (add (exp ?a) (exp ?b)))" => "(lse ?a ?b)" := - rfl; - -register_rewrite_map "lse_unfold"; "(lse ?a ?b)" => "(log (add (exp ?a) (exp ?b)))" := - rfl; - -register_rewrite_map "norm2_fold"; "(sqrt (add (pow ?a 2) (pow ?b 2)))" => "(norm2 ?a ?b)" := - rfl; - -register_rewrite_map "norm2_unfold"; "(norm2 ?a ?b)" => "(sqrt (add (pow ?a 2) (pow ?b 2)))" := - rfl; - -end CvxLean diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean b/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean new file mode 100644 index 00000000..df3be186 --- /dev/null +++ b/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean @@ -0,0 +1,45 @@ +import Lean +import CvxLean.Tactic.PreDCP.RuleToTacticExt + +namespace CvxLean + +open Lean Parser + +syntax (name := registerRuleToTactic) + "register_rule_to_tactic " str " ; " str " => " str " := " tactic ";" : command + +syntax (name := registerRuleToTacticBidirectional) + "register_rule_to_tactic " str " ; " str " <=> " str " := " tactic ";" : command + +macro_rules + | `(register_rule_to_tactic $rwName ; $rwTarget <=> $rwGoal := $tac;) => + if let some rwNameStr := Syntax.isStrLit? rwName then + let rwNameRev := Syntax.mkStrLit (rwNameStr ++ "-rev") + `(register_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac; + register_rule_to_tactic $rwNameRev ; $rwGoal => $rwTarget := $tac;) + else `(throwError "`register_rule_to_tactic` error: expected string.") + +syntax (name := registerObjFunRuleToTactic) + "register_objFun_rule_to_tactic " str " ; " str " => " str " := " tactic ";" : command + +open Lean.Elab Lean.Elab.Command + +@[command_elab registerRuleToTactic] def elabRuleToTactic : CommandElab +| `(register_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac;) => do + let rwNameStr := rwName.getString + -- NOTE: We ignore this for now. + let _rwTargetStr := rwTarget.getString + let _rwGoalStr := rwGoal.getString + liftTermElabM <| addRuleToTacticEntry rwNameStr tac false +| _ => throwUnsupportedSyntax + +@[command_elab registerObjFunRuleToTactic] def elabRegisterObjFunRuleToTactic : CommandElab +| `(register_objFun_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac;) => do + let rwNameStr := rwName.getString + -- NOTE: We ignore this for now. + let _rwTargetStr := rwTarget.getString + let _rwGoalStr := rwGoal.getString + liftTermElabM <| addRuleToTacticEntry rwNameStr tac true +| _ => throwUnsupportedSyntax + +end CvxLean diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean b/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean new file mode 100644 index 00000000..b9f8ef8b --- /dev/null +++ b/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean @@ -0,0 +1,48 @@ +import Lean + +/-! +# Egg rewrite rule to tactic (extension) + +This files defines the persistent environment extension used to store translation of `egg` rewrite +rule names to Lean tactics. +-/ + +namespace CvxLean + +open Lean Meta + +/-- An entry is the name of the rewrite, the corresponding tactic, and whether +it is an objective function map or not. -/ +def RuleToTacticExtEntry : Type := String × ((TSyntax `tactic) × Bool) +deriving Inhabited + +def RuleToTacticExtState : Type := HashMap String ((TSyntax `tactic) × Bool) +deriving Inhabited + +/-- Environment extension to store the mapping between rewrite names in egg and +Lean tactics. -/ +def RuleToTacticExt : Type := + SimplePersistentEnvExtension RuleToTacticExtEntry RuleToTacticExtState +deriving Inhabited + +initialize ruleToTacticExt : RuleToTacticExt ← + registerSimplePersistentEnvExtension { + addImportedFn := fun as => + mkStateFromImportedEntries (fun s (k, v) => s.insert k v) default as, + addEntryFn := fun s (k, v) => s.insert k v, + } + +/-- Add a new rewrite mapping to the environment. -/ +def addRuleToTacticEntry (rwName : String) (tac : TSyntax `tactic) (mapObjFun : Bool) : + MetaM Unit := do + setEnv <| ruleToTacticExt.addEntry (← getEnv) (rwName, (tac, mapObjFun)) + +/-- Given rewrite name, return associated tactic in the environment. -/ +def getTacticFromRewriteName (rwName : String) : MetaM (Option (TSyntax `tactic × Bool)) := do + return (ruleToTacticExt.getState (← getEnv)).find? rwName + +/-- Return all the saved rewrite names. -/ +def getRegisteredRewriteNames : MetaM (List String) := do + return (ruleToTacticExt.getState (← getEnv)).toList.map (·.1) + +end CvxLean diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean b/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean new file mode 100644 index 00000000..74c8af2b --- /dev/null +++ b/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean @@ -0,0 +1,376 @@ +import CvxLean.Lib.Equivalence +import CvxLean.Tactic.PreDCP.RuleToTacticCmd +import CvxLean.Tactic.Arith.Arith + +/-! + +-/ + +namespace CvxLean + +set_option linter.unreachableTactic false + +/- Objective function rules. -/ + +register_objFun_rule_to_tactic + "map_objFun_log"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (log ?a)) ?cs)" := + apply Minimization.Equivalence.map_objFun_log (by positivity!); + +register_objFun_rule_to_tactic + "map_objFun_sq"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (pow ?a 2)) ?cs)" := + apply Minimization.Equivalence.map_objFun_sq (by positivity!); + + +/- Equality rules. -/ + +register_rule_to_tactic "log_eq_log" ; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" := + apply Real.log_eq_log (by positivity!) (by positivity!); + +register_rule_to_tactic "log_eq_log-rev" ; "(eq (log ?a) (log ?b))" => "(eq ?a ?b)" := + apply (Real.log_eq_log (by positivity!) (by positivity!)).symm; + + +/- Less than or equal rules. -/ + +register_rule_to_tactic "le_sub_iff_add_le" ; "(le ?a (sub ?b ?c))" => "(le (add ?a ?c) ?b)" := + apply le_sub_iff_add_le; + +register_rule_to_tactic "le_sub_iff_add_le-rev" ; "(le (add ?a ?c) ?b)" => "(le ?a (sub ?b ?c))" := + apply le_sub_iff_add_le.symm; + +register_rule_to_tactic "sub_le_iff_le_add"; "(le (sub ?a ?c) ?b)" => "(le ?a (add ?b ?c))" := + apply sub_le_iff_le_add; + +register_rule_to_tactic "sub_le_iff_le_add-rev"; "(le ?a (add ?b ?c))" => "(le (sub ?a ?c) ?b)" := + apply sub_le_iff_le_add.symm; + +register_rule_to_tactic "div_le_iff" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := + apply div_le_iff (by positivity!); + +register_rule_to_tactic "div_le_iff-rev" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := + apply (div_le_iff (by positivity!)).symm; + +register_rule_to_tactic "div_le_one-rev" ; "(le ?a ?b)" => "(le (div ?a ?b) 1)" := + apply (div_le_one (by positivity!)).symm; + +register_rule_to_tactic "log_le_log" ; "(le (log ?a) (log ?b))" => "(le ?a ?b)" := + apply Real.log_le_log_iff (by positivity!) (by positivity!); + +register_rule_to_tactic "log_le_log-rev" ; "(le ?a ?b)" => "(le (log ?a) (log ?b))" := + apply (Real.log_le_log_iff (by positivity!) (by positivity!)).symm; + +register_rule_to_tactic "pow_two_le_pow_two"; "(le (pow ?a 2) (pow ?b 2))" => "(le ?a ?b)":= + apply Real.pow_two_le_pow_two (by positivity!) (by positivity!); + +register_rule_to_tactic "pow_two_le_pow_two-rev"; "(le ?a ?b)" => "(le (pow ?a 2) (pow ?b 2))" := + apply (Real.pow_two_le_pow_two (by positivity!) (by positivity!)).symm; + + +/- Field rules. -/ + +register_rule_to_tactic "neg_neg" ; "(neg (neg ?a))" => "?a" := + apply neg_neg (G := ℝ); + +register_rule_to_tactic "add_zero" ; "(add ?a 0)" => "?a" := + apply add_zero; + +register_rule_to_tactic "add_comm" ; "(add ?a ?b)" => "(add ?b ?a)" := + apply add_comm; + +register_rule_to_tactic "add_assoc" ; "(add (add ?a ?b) ?c)" => "(add ?a (add ?b ?c))" := + apply add_assoc; + +register_rule_to_tactic "sub_self" ; "(sub ?a ?a)" => "0" := + apply sub_self; + +register_rule_to_tactic "one_mul" ; "(mul 1 ?a)" => "?a" := + apply one_mul; + +-- Exception, we cannot find the pattern otherwise. +register_rule_to_tactic "one_mul-rev" ; "?a" => "(mul 1 ?a)" := + norm_num; + +register_rule_to_tactic "mul_zero" ; "(mul ?a 0)" => "0" := + apply mul_zero (M₀ := ℝ); + +register_rule_to_tactic "mul_comm" ; "(mul ?a ?b)" => "(mul ?b ?a)" := + apply mul_comm; + +register_rule_to_tactic "mul_assoc" ; "(mul (mul ?a ?b) ?c)" => "(mul ?a (mul ?b ?c))" := + apply mul_assoc; + +register_rule_to_tactic "sub_eq_add_neg"; "(sub ?a ?b)" => "(add ?a (neg ?b))" := + apply sub_eq_add_neg (G := ℝ); + +register_rule_to_tactic "sub_eq_add_neg-rev"; "(add ?a (neg ?b))" => "(sub ?a ?b)" := + apply (sub_eq_add_neg (G := ℝ) _ _).symm; + +register_rule_to_tactic "add_sub" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := + apply add_sub; + +register_rule_to_tactic "add_sub-rev" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := + apply (add_sub _ _ _).symm; + +register_rule_to_tactic "add_mul" ; "(mul (add ?a ?b) ?c)" => "(add (mul ?a ?c) (mul ?b ?c))" := + apply add_mul; + +register_rule_to_tactic "add_mul-rev" ; "(add (mul ?a ?c) (mul ?b ?c))" => "(mul (add ?a ?b) ?c)" := + apply (add_mul _ _ _).symm; + +register_rule_to_tactic "mul_add" ; "(mul ?a (add ?b ?c))" => "(add (mul ?a ?b) (mul ?a ?c))" := + apply mul_add; + +register_rule_to_tactic "mul_add-rev" ; "(add (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (add ?b ?c))" := + apply (mul_add _ _ _).symm; + +register_rule_to_tactic "mul_sub" ; "(mul ?a (sub ?b ?c))" => "(sub (mul ?a ?b) (mul ?a ?c))" := + apply mul_sub; + +register_rule_to_tactic "mul_sub-rev" ; "(sub (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (sub ?b ?c))" := + apply (mul_sub _ _ _).symm; + +register_rule_to_tactic "div_one" ; "(div ?a 1)" => "?a" := + apply div_one; + +register_rule_to_tactic "add_div" ; "(div (add ?a ?b) ?c)" => "(add (div ?a ?c) (div ?b ?c))" := + apply add_div; + +register_rule_to_tactic "add_div-rev" ; "(add (div ?a ?c) (div ?b ?c))" => "(div (add ?a ?b) ?c)" := + apply (add_div _ _ _).symm; + +register_rule_to_tactic "mul_div" ; "(mul ?a (div ?b ?c))" => "(div (mul ?a ?b) ?c)" := + apply mul_div (G := ℝ); + +register_rule_to_tactic "mul_div-rev" ; "(div (mul ?a ?b) ?c)" => "(mul ?a (div ?b ?c))" := + apply (mul_div (G := ℝ) _ _ _).symm; + +register_rule_to_tactic "div_mul" ; "(mul (div ?a ?b) ?c)" => "(div ?a (div ?b ?c))" := + apply div_mul; + +register_rule_to_tactic "div_mul-rev" ; "(div ?a (div ?b ?c))" => "(mul (div ?a ?b) ?c)" := + apply (div_mul _ _ _).symm; + +register_rule_to_tactic "mul_div_mul_left" ; "(div (mul ?c ?a) (mul ?c ?b))" => "(div ?a ?b)" := + apply mul_div_mul_left (G₀ := ℝ) _ _ (by positivity!); + +register_rule_to_tactic "div_div" ; "(div (div ?a ?b) ?c)" => "(div ?a (mul ?b ?c))" := + apply div_div; + +register_rule_to_tactic "div_div-rev" ; "(div ?a (mul ?b ?c))" => "(div (div ?a ?b) ?c)" := + apply (div_div _ _ _).symm; + +register_rule_to_tactic "div_self" ; "(div ?a ?a)" => "1" := + apply div_self (G₀ := ℝ) (by positivity!); + +register_rule_to_tactic "inv_eq_one_div" ; "(inv ?a)" => "(div 1 ?a)" := + apply inv_eq_one_div (G := ℝ); + +register_rule_to_tactic "inv_inv" ; "(inv (inv ?a))" => "?a" := + apply inv_inv; + + +/- Power and square root rules. -/ + +register_rule_to_tactic "one_pow"; "(pow 1 ?a)" => "1" := + apply Real.one_rpow; + +register_rule_to_tactic "pow_one"; "(pow ?a 1)" => "?a" := + apply Real.rpow_one; + +register_rule_to_tactic "pow_add"; "(pow ?a (add ?b ?c))" => "(mul (pow ?a ?b) (pow ?a ?c))" := + apply Real.rpow_add (by positivity!); + +register_rule_to_tactic "pow_add-rev"; "(mul (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (add ?b ?c))" := + apply (Real.rpow_add (by positivity!) _ _).symm; + +register_rule_to_tactic "mul_pow"; "(mul (pow ?a ?n) (pow ?b ?n))" => "(pow (mul ?a ?b) ?n)" := + apply Real.mul_rpow (by positivity!) (by positivity!); + +register_rule_to_tactic "mul_pow-rev"; "(mul (pow ?a ?n) (pow ?b ?n))" => "(pow (mul ?a ?b) ?n)" := + apply (Real.mul_rpow (by positivity!) (by positivity!)).symm; + +register_rule_to_tactic "pow_mul"; "(pow ?a (mul ?n ?m))" => "(pow (pow ?a ?n) ?m)" := + apply Real.rpow_mul (by positivity!); + +register_rule_to_tactic "pow_mul-rev"; "(pow (pow ?a ?n) ?m)" => "(pow ?a (mul ?n ?m))" := + apply (Real.rpow_mul (by positivity!) _ _).symm; + +register_rule_to_tactic "div_pow"; "(pow (div ?a ?b) ?n)" => "(div (pow ?a ?n) (pow ?b ?n))" := + apply Real.div_rpow (by positivity!) (by positivity!); + +register_rule_to_tactic "div_pow-rev"; "(div (pow ?a ?n) (pow ?b ?n))" => "(pow (div ?a ?b) ?n)" := + apply (Real.div_rpow (by positivity!) (by positivity!) _).symm; + +register_rule_to_tactic "pow_sub"; "(pow ?a (sub ?b ?c))" => "(div (pow ?a ?b) (pow ?a ?c))" := + apply Real.rpow_sub (by positivity!); + +register_rule_to_tactic "pow_sub-rev"; "(div (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (sub ?b ?c))" := + apply (Real.rpow_sub (by positivity!) _ _).symm; + +register_rule_to_tactic + "div_pow_eq_mul_pow_neg" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := + apply Real.div_pow_eq_mul_pow_neg (by positivity!); + +register_rule_to_tactic + "div_pow_eq_mul_pow_neg-rev" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := + apply (Real.div_pow_eq_mul_pow_neg (by positivity!)).symm; + +register_rule_to_tactic "one_div_eq_pow_neg_one"; "(div 1 ?a)" <=> "(pow ?a (neg 1))" := + apply Real.one_div_eq_pow_neg_one (by positivity!); + +register_rule_to_tactic "sqrt_eq_rpow" ; "(sqrt ?a)" => "(pow ?a 0.5)" := + apply Real.sqrt_eq_rpow; + +register_rule_to_tactic "sqrt_eq_rpow-rev" ; "(pow ?a 0.5)" => "(sqrt ?a)" := + apply (Real.sqrt_eq_rpow _).symm; + +register_rule_to_tactic "sqrt_mul" ; "(sqrt (mul ?a ?b))" => "(mul (sqrt ?a) (sqrt ?b))" := + apply Real.sqrt_mul (by positivity!); + +register_rule_to_tactic "sqrt_mul-rev" ; "(mul (sqrt ?a) (sqrt ?b))" => "(sqrt (mul ?a ?b))" := + apply (Real.sqrt_mul (by positivity!) _).symm; + +register_rule_to_tactic "pow_two"; "(pow ?a 2)" => "(mul ?a ?a)" := + (norm_cast; apply pow_two); + +register_rule_to_tactic "pow_two-rev"; "(mul ?a ?a)" => "(pow ?a 2)" := + (norm_cast; apply (pow_two _).symm); + +register_rule_to_tactic "pow_half_two"; "(pow (pow ?a 0.5) 2)" => "?a" := + apply Real.pow_half_two (by positivity!); + +register_rule_to_tactic "pow_half_two-rev"; "?a" => "(pow (pow ?a 0.5) 2)" := + apply (Real.pow_half_two (by positivity!)).symm; + +register_rule_to_tactic "binomial_two"; + "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))" := + apply Real.binomial_two; + +register_rule_to_tactic "rpow_eq_mul_rpow_pred"; "(pow ?a ?b)" => "(mul ?a (pow ?a (sub ?b 1)))" := + apply Real.rpow_eq_mul_rpow_pred (by positivity!); + +register_rule_to_tactic "inv_eq_pow_neg_one"; "(inv ?a)" => "(pow ?a (neg 1))" := + apply Real.inv_eq_pow_neg_one (by positivity!); + + +/- Exponential and logarithm rules. -/ + +register_rule_to_tactic "exp_add" ; "(exp (add ?a ?b))" => "(mul (exp ?a) (exp ?b))" := + apply Real.exp_add; + +register_rule_to_tactic "exp_add-rev" ; "(mul (exp ?a) (exp ?b))" => "(exp (add ?a ?b))" := + apply (Real.exp_add _ _).symm; + +register_rule_to_tactic "exp_sub" ; "(exp (sub ?a ?b))" => "(div (exp ?a) (exp ?b))" := + apply Real.exp_sub; + +register_rule_to_tactic "exp_sub-rev" ; "(div (exp ?a) (exp ?b))" => "(exp (sub ?a ?b))" := + apply (Real.exp_sub _ _).symm; + +register_rule_to_tactic "exp_mul" ; "(exp (mul ?a ?b))" => "(pow (exp ?a) ?b)" := + apply Real.exp_mul; + +register_rule_to_tactic "exp_mul-rev" ; "(pow (exp ?a) ?b)" => "(exp (mul ?a ?b))" := + apply (Real.exp_mul _ _).symm; + +register_rule_to_tactic "exp_neg_eq_one_div"; "(exp (neg ?a))" => "(div 1 (exp ?a))" := + apply Real.exp_neg_eq_one_div; + +register_rule_to_tactic "exp_neg_eq_one_div-rev" ; "(div 1 (exp ?a))" => "(exp (neg ?a))" := + apply (Real.exp_neg_eq_one_div _).symm; + +register_rule_to_tactic "log_mul" ; "(log (mul ?a ?b))" => "(add (log ?a) (log ?b))" := + apply Real.log_mul (by positivity!) (by positivity!); + +register_rule_to_tactic "log_mul-rev"; "(add (log ?a) (log ?b))" => "(log (mul ?a ?b))" := + apply (Real.log_mul (by positivity!) (by positivity!)).symm; + +register_rule_to_tactic "log_div" ; "(log (div ?a ?b))" => "(sub (log ?a) (log ?b))" := + apply Real.log_div (by positivity!) (by positivity!); + +register_rule_to_tactic "log_div-rev"; "(sub (log ?a) (log ?b))" => "(log (div ?a ?b))" := + apply (Real.log_div (by positivity!) (by positivity!)).symm; + +register_rule_to_tactic "log_pow" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := + apply (Real.log_rpow (by positivity!)); + +register_rule_to_tactic "log_pow-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := + apply (Real.log_rpow (by positivity!) _).symm; + +-- Special type of rule as it only works if the exponent is a natural number. +register_rule_to_tactic "log_pow_nat" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := + (norm_cast; apply Real.log_pow); + +register_rule_to_tactic "log_pow_nat-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := + (norm_cast; apply (Real.log_pow _ _).symm); + +register_rule_to_tactic "exp_log" ; "(exp (log ?a))" => "?a" := + apply Real.exp_log (by positivity!); + +register_rule_to_tactic "log_exp" ; "(log (exp ?a))" => "?a" := + apply Real.log_exp; + + +/- Absolute value rules. -/ + +register_rule_to_tactic "abs_nonneg" ; "(abs ?a)" => "?a" := + apply abs_eq_self.mpr (by positivity!); + +register_rule_to_tactic "abs_nonpos" ; "(abs ?a)" => "(neg ?a)" := + apply abs_eq_neg_self.mpr (by linarith); + + +/- Min and max rules. -/ + +register_rule_to_tactic "min_eq_left"; "(min ?a ?b)" => "?a" := + apply min_eq_left (by linarith); + +register_rule_to_tactic "min_eq_right"; "(min ?a ?b)" => "?b" := + apply min_eq_right (by linarith); + +register_rule_to_tactic "max_eq_left"; "(max ?a ?b)" => "?a" := + apply max_eq_left (by linarith); + +register_rule_to_tactic "max_eq_right"; "(max ?a ?b)" => "?b" := + apply max_eq_right (by linarith); + + +/- Atom folding. -/ + +register_rule_to_tactic "xexp_fold"; "(mul ?a (exp ?a))" => "(xexp ?a)" := + rfl; + +register_rule_to_tactic "xexp_unfold"; "(xexp ?a)" => "(mul ?a (exp ?a))" := + rfl; + +register_rule_to_tactic "entr_fold"; "(neg (mul ?a (log ?a)))" => "(entr ?a)" := + rfl; + +register_rule_to_tactic "entr_unfold"; "(entr ?a)" => "(neg (mul ?a (log ?a)))" := + rfl; + +register_rule_to_tactic "qol_fold"; "(div (pow ?a 2) ?b)" => "(qol ?a ?b)" := + rfl; + +register_rule_to_tactic "qol_unfold"; "(qol ?a ?b)" => "(div (pow ?a 2) ?b)" := + rfl; + +register_rule_to_tactic "geo_fold"; "(sqrt (mul ?a ?b))" => "(geo ?a ?b)" := + rfl; + +register_rule_to_tactic "geo_unfold"; "(geo ?a ?b)" => "(sqrt (mul ?a ?b))" := + rfl; + +register_rule_to_tactic "lse_fold"; "(log (add (exp ?a) (exp ?b)))" => "(lse ?a ?b)" := + rfl; + +register_rule_to_tactic "lse_unfold"; "(lse ?a ?b)" => "(log (add (exp ?a) (exp ?b)))" := + rfl; + +register_rule_to_tactic "norm2_fold"; "(sqrt (add (pow ?a 2) (pow ?b 2)))" => "(norm2 ?a ?b)" := + rfl; + +register_rule_to_tactic "norm2_unfold"; "(norm2 ?a ?b)" => "(sqrt (add (pow ?a 2) (pow ?b 2)))" := + rfl; + +end CvxLean diff --git a/CvxLean/Test/BasicTactics/Other.lean b/CvxLean/Test/BasicTactics/Other.lean index 3a8eb80a..cd23ef7a 100644 --- a/CvxLean/Test/BasicTactics/Other.lean +++ b/CvxLean/Test/BasicTactics/Other.lean @@ -79,7 +79,7 @@ example : Solution <| cz : z = x + y cz' : z = x + y cz'' : z ≤ x + 4 := by - remove_constr cz' by { exact cz } + remove_constr cz' => { exact cz } rename_constrs [czz, czz2] sorry @@ -121,6 +121,23 @@ equivalence eqv2/q1 : ring_nf rfl +def p := + optimization (x y : ℝ) + maximize (2 * x : ℝ) + subject to + c₁ : 0 ≤ x + c₂ : 1 < y + c₃ : log (y - 1) ≤ 2 * sqrt x + 1 + c₄ : 3 * x + 5 * y ≤ 10 + +equivalence' eqv3/q3 : p := by + change_of_variables! (v) (y ↦ v + 1) + change_of_variables! (w) (v ↦ exp w) + remove_constr c₂ => + field_simp; arith + rw_constr c₃ into (w ≤ 2 * sqrt x + 1) => + field_simp + end end BasicTacticTest diff --git a/CvxLean/Test/DCP/Basic.lean b/CvxLean/Test/DCP/Basic.lean index d283a6d0..3e0f51bc 100644 --- a/CvxLean/Test/DCP/Basic.lean +++ b/CvxLean/Test/DCP/Basic.lean @@ -114,6 +114,15 @@ def testLogDet {n : ℕ} : Solution <| dcp sorry +open BigOperators in +def testBinders {n} : Solution <| + optimization (x : Fin n → ℝ) + minimize ∑ i, x i + subject to + c1 : ∀ i, 1 ≤ x i := by + dcp + sorry + end Atoms section Misc diff --git a/CvxLean/Test/PreDCP/All.lean b/CvxLean/Test/PreDCP/All.lean index d55777e3..1ac62d2b 100644 --- a/CvxLean/Test/PreDCP/All.lean +++ b/CvxLean/Test/PreDCP/All.lean @@ -4,3 +4,4 @@ import CvxLean.Test.PreDCP.DGP import CvxLean.Test.PreDCP.DQCP import CvxLean.Test.PreDCP.MainExample import CvxLean.Test.PreDCP.Quiz +import CvxLean.Test.PreDCP.Stanford diff --git a/CvxLean/Test/PreDCP/DGP.lean b/CvxLean/Test/PreDCP/DGP.lean index eb5c4542..a1485dce 100644 --- a/CvxLean/Test/PreDCP/DGP.lean +++ b/CvxLean/Test/PreDCP/DGP.lean @@ -130,7 +130,7 @@ def gp5 := h3 : 0 < z h4 : 2 ≤ x h5 : x ≤ 3 - h6 : x ^ (2 : ℝ) + 3 * y / z ≤ sqrt x + h6 : x ^ (2 : ℝ) + 3 * y / z ≤ sqrt y h7 : x / y = z ^ (2 : ℝ) time_cmd reduction red5/dcp5 : gp5 := by @@ -140,24 +140,58 @@ time_cmd reduction red5/dcp5 : gp5 := by pre_dcp #print dcp5 --- optimization (x : ℝ) (y : ℝ) (z : ℝ) --- minimize -(x - y) +-- optimization (u : ℝ) (v : ℝ) (w : ℝ) +-- minimize v - u -- subject to --- h4 : log 2 ≤ x --- h5 : x ≤ log 3 --- h6 : exp (x * (3 / 2)) + 3 * exp (y - z - x * (1 / 2)) ≤ 1 --- h7 : x - y = 2 * z +-- h4 : log 2 ≤ u +-- h5 : u ≤ log 3 +-- h6 : rexp (u * 2 - v * (1 / 2)) ≤ 1 - 3 * rexp (v - w - v * (1 / 2)) +-- h7 : u - v = 2 * w solve dcp5 end GP5 +section GP6 + +-- NOTE: `maximize` does not work because it is set to `Neg.neg`. +def gp6 := + optimization (x y z : ℝ) + minimize 1 / (x / y) + subject to + h1 : 0 < x + h2 : 0 < y + h3 : 0 < z + h4 : 2 ≤ x + h5 : x ≤ 3 + h6 : x ^ (2 : ℝ) + 3 * y / z ≤ 5 * sqrt y + h7 : x * y = z ^ (2 : ℝ) + +time_cmd reduction red6/dcp6 : gp6 := by + change_of_variables! (u) (x ↦ exp u) + change_of_variables! (v) (y ↦ exp v) + change_of_variables! (w) (z ↦ exp w) + pre_dcp + +#print dcp6 +-- optimization (u : ℝ) (v : ℝ) (w : ℝ) +-- minimize v - u +-- subject to +-- h4 : log 2 ≤ u +-- h5 : u ≤ log 3 +-- h6 : rexp (u * 2 - v * (1 / 2)) + 3 * rexp (v * (1 / 2) - w) ≤ 5 +-- h7 : u + v = 2 * w + +--solve dcp6 + +end GP6 + /- In https://web.stanford.edu/~boyd/papers/pdf/gp_tutorial.pdf (5) and in https://www.cvxpy.org/examples/dgp/max_volume_box.html -/ -section GP6 +section GP7 -- NOTE: `maximize` issue. -def gp6 := +def gp7 := optimization (h w d : ℝ) minimize (1 / (h * w * d)) subject to @@ -171,13 +205,13 @@ def gp6 := h8 : 5 ≤ d / w h9 : d / w ≤ 6 -time_cmd reduction red6/dcp6 : gp6 := by +time_cmd reduction red7/dcp7 : gp7 := by change_of_variables! (h') (h ↦ exp h') change_of_variables! (w') (w ↦ exp w') change_of_variables! (d') (d ↦ exp d') pre_dcp -#print dcp6 +#print dcp7 -- optimization (h : ℝ) (w : ℝ) (d : ℝ) -- minimize -(h + (d + w)) -- subject to @@ -188,16 +222,16 @@ time_cmd reduction red6/dcp6 : gp6 := by -- h8 : log 5 ≤ d - w -- h9 : d - w ≤ log 6 -solve dcp6 +solve dcp7 -end GP6 +end GP7 /- In https://web.stanford.edu/~boyd/papers/pdf/gp_tutorial.pdf section 2.2. -/ -section GP7 +section GP8 -- objFun : (x ^ (-1)) * y ^ (-1 / 2) * z ^ (-1) + 2.3 * x * z + 4 * x * y * z -- h4 : (1 / 3) * x ^ (-2) * y ^ (-2) + (4 / 3) * y ^ (1 / 2) * z ^ (-1) ≤ 1 -def gp7 := +def gp8 := optimization (x y z : ℝ) minimize (1 / x) * (1 / sqrt y) * (1 / z) + (2.3) * x * z + 4 * x * y * z subject to @@ -208,13 +242,13 @@ def gp7 := h5 : x + 2 * y + 3 * z ≤ 1 h6 : (1 / 2) * x * y = 1 -time_cmd reduction red7/dcp7 : gp7 := by +time_cmd reduction red8/dcp8 : gp8 := by change_of_variables! (u) (x ↦ exp u) change_of_variables! (v) (y ↦ exp v) change_of_variables! (w) (z ↦ exp w) pre_dcp -#print dcp7 +#print dcp8 -- optimization (u : ℝ) (v : ℝ) (w : ℝ) -- minimize rexp (-(w + (u + v * (1 / 2)))) + (23 / 10 * rexp (u + w) + 4 * rexp (u + (v + w))) -- subject to @@ -222,15 +256,15 @@ time_cmd reduction red7/dcp7 : gp7 := by -- h5 : rexp v * 2 ≤ 1 - rexp u - rexp w * 3 -- h6 : u + (v + log (1 / 2)) = 0 -solve dcp7 +solve dcp8 -end GP7 +end GP8 /- In https://web.stanford.edu/~boyd/papers/pdf/gp_tutorial.pdf section 6.1. -/ -section GP8 +section GP9 -- hmin = wmin = 1, hmax = wmax = 100, Rmax = 10, σ = 0.5, π ≈ 3.14159. -def gp8 := +def gp9 := optimization (h w A r : ℝ) minimize 2 * A * sqrt (w ^ (2 : ℝ) + h ^ (2 : ℝ)) subject to @@ -247,16 +281,16 @@ def gp8 := h11 : 1.1 * r ≤ sqrt (A / (2 * 3.14159) + r ^ (2 : ℝ)) h12 : sqrt (A / (2 * 3.14159) + r ^ (2 : ℝ)) ≤ 10 -time_cmd reduction red8/dcp8 : gp8 := by +time_cmd reduction red9/dcp9 : gp9 := by change_of_variables! (h') (h ↦ exp h') change_of_variables! (w') (w ↦ exp w') change_of_variables! (A') (A ↦ exp A') change_of_variables! (r') (r ↦ exp r') pre_dcp -solve dcp8 +solve dcp9 -end GP8 +end GP9 end diff --git a/CvxLean/Test/PreDCP/DQCP.lean b/CvxLean/Test/PreDCP/DQCP.lean index 3baedd37..0ffa868a 100644 --- a/CvxLean/Test/PreDCP/DQCP.lean +++ b/CvxLean/Test/PreDCP/DQCP.lean @@ -37,14 +37,18 @@ time_cmd reduction redqcp1/dqcp1 : qcp1 := by solve dqcp1 -#eval dqcp1.value -- -2.000000 +#print dqcp1.reduced -- ... +#eval dqcp1.value -- -2.000000 +#eval dqcp1.solution -- (2.000000, 2.000000) end QCP1 -/- https://www.cvxpy.org/examples/dqcp/hypersonic_shape_design.html -/ section QCP2 -def hypersonicShapeDesign (a b : ℝ) := +/-! Adapted from https://www.cvxpy.org/examples/dqcp/hypersonic_shape_design.html. +See `CvxLean/Examples/HypersonicShapeDesign.lean` for a complete analysis. -/ + +def qcp2 (a b : ℝ) := optimization (Δx : ℝ) minimize sqrt (1 / (Δx ^ (2 : ℝ)) - 1) subject to @@ -52,32 +56,59 @@ def hypersonicShapeDesign (a b : ℝ) := h2 : Δx ≤ 1 h3 : a * (1 / Δx) - (1 - b) * sqrt (1 - Δx ^ (2 : ℝ)) ≤ 0 -time_cmd equivalence redqcp2/dqcp2 : hypersonicShapeDesign 0.05 0.65 := by +time_cmd equivalence redqcp2/dqcp2 : qcp2 0.05 0.65 := by pre_dcp solve dqcp2 -#print dqcp2.reduced --- optimization (Δx : ℝ) (t₀.0 : ℝ) (t₁.1 : ℝ) (t.2 : ℝ) (y'.3 : ℝ) (t.4 : ℝ) (t.5 : ℝ) --- minimize t₀.0 - 1 --- subject to --- _ : posOrthCone (1 - Δx) --- _ : posOrthCone (7 / 20 * t.5 - t.2 / 20) --- _ : soCone (Δx + t₁.1) ![Δx - t₁.1, 2] --- _ : soCone (t₀.0 + 1) ![t₀.0 - 1, 2 * t₁.1] --- _ : posOrthCone (t₀.0 - 0) --- _ : posOrthCone (t₁.1 - 0) --- _ : posOrthCone (Δx - 0) --- _ : soCone (Δx + t.2) ![Δx - t.2, 2 * 1] --- _ : posOrthCone (t.2 - 0) --- _ : rotatedSoCone t.4 (1 / 2) ![Δx] --- _ : rotatedSoCone (1 - t.4) (1 / 2) ![t.5] - -#eval dqcp2.value -#eval dqcp2.solution +#print dqcp2.reduced -- ... +#eval dqcp2.value -- 0.021286 +#eval dqcp2.solution -- 0.989524 end QCP2 +section QCP3 + +def qcp3 := + optimization (x y : ℝ) + minimize (0 : ℝ) + subject to + c₁ : 0.001 <= y + c₂ : x / y ≤ 3 + c₃ : x = 12 + c₄ : y ≤ 6 + +time_cmd reduction redqcp3/dqcp3 : qcp3 := by + pre_dcp + +solve dqcp3 + +#print dqcp3.reduced -- ... +#eval dqcp3.value -- 0.000000 +#eval dqcp3.solution -- (12.000000, 4.000000) + +end QCP3 + +section QCP4 + +def qcp4 := + optimization (x : ℝ) + minimize -x + subject to + c₁ : sqrt ((x + 1) ^ 2 + 4) / sqrt (x ^ 2 + 100) ≤ 1 + c₂ : 10 ≤ x + +time_cmd reduction redqcp4/dqcp4 : qcp4 := by + pre_dcp + +solve dqcp4 + +#print dqcp4.reduced -- ... +#eval dqcp4.value -- -47.500000 +#eval dqcp4.solution -- 47.500000 + +end QCP4 + end end DQCP diff --git a/CvxLean/Test/PreDCP/Stanford.lean b/CvxLean/Test/PreDCP/Stanford.lean index cc135f4a..8f947107 100644 --- a/CvxLean/Test/PreDCP/Stanford.lean +++ b/CvxLean/Test/PreDCP/Stanford.lean @@ -5,18 +5,74 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP - namespace Stanford noncomputable section open CvxLean Minimization Real -/- Exercise 3.67. -/ -section E367 +/- Exercise 3.36 (c). -/ +section Stan1 + +def stan1 := + optimization (x y : ℝ) + minimize log (exp (2 * x + 3) + exp (4 * y + 5)) + subject to + h1 : 0 ≤ x + +time_cmd reduction red1/dcp1 : stan1 := by + pre_dcp + +#print dcp1 + +solve dcp1 + +end Stan1 + +/- Exercise 3.38 (e), which is exercise 3.67 with n = 2. -/ +section Stan2 + +def stan2 := + optimization (x1 x2 : ℝ) + minimize (-(sqrt x1 + sqrt x2) ^ (2 : ℝ)) + subject to + h1 : 0.001 ≤ x1 + h2 : 0.001 ≤ x2 + +time_cmd reduction red2/dcp2 : stan2 := by + pre_dcp + +#print dcp2 + +solve dcp2 + +end Stan2 + +/- Exercise 3.67 with n = 3. -/ +section Stan3 + +def stan3 := + optimization (x1 x2 x3 : ℝ) + minimize (-(sqrt x1 + sqrt x2 + sqrt x3) ^ (2 : ℝ)) + subject to + h1 : 0.001 ≤ x1 + h2 : 0.001 ≤ x2 + h3 : 0.001 ≤ x3 + +time_cmd reduction red3/dcp3 : stan3 := by + pre_dcp + +#print dcp3 + +solve dcp3 + +end Stan3 + +/- Exercise 3.67 with n = 4. -/ +section Stan4 -def e367 := - optimization (x1 x2 x3 x4: ℝ) +def stan4 := + optimization (x1 x2 x3 x4 : ℝ) minimize (-(sqrt x1 + sqrt x2 + sqrt x3 + sqrt x4) ^ (2 : ℝ)) subject to h1 : 0.001 ≤ x1 @@ -24,14 +80,14 @@ def e367 := h3 : 0.001 ≤ x3 h4 : 0.001 ≤ x4 -time_cmd reduction red367/dcp367 : e367 := by +time_cmd reduction red4/dcp4 : stan4 := by pre_dcp -#print dcp367 +#print dcp4 -solve dcp367 +solve dcp4 -end E367 +end Stan4 end diff --git a/CvxLean/Test/PreDCP/dqcp.py b/CvxLean/Test/PreDCP/dqcp.py index 0faf6a7f..306a7d20 100644 --- a/CvxLean/Test/PreDCP/dqcp.py +++ b/CvxLean/Test/PreDCP/dqcp.py @@ -1,19 +1,72 @@ import cvxpy as cp +import numpy as np -x = cp.Variable(pos=True) -y = cp.Variable(nonneg=True) +def dqcp1(): + x = cp.Variable(pos=True) + y = cp.Variable(nonneg=True) -p = cp.Problem( - cp.Minimize(-y), [ - 1 <= x, - x <= 2, - cp.sqrt((2 * y) / (x + y)) <= 1 - ]) + p = cp.Problem( + cp.Minimize(-y), [ + 1 <= x, + x <= 2, + cp.sqrt((2 * y) / (x + y)) <= 1 + ]) -assert(p.is_dqcp()) -assert(not p.is_dgp()) -assert(not p.is_dcp()) + assert(p.is_dqcp()) + assert(not p.is_dgp()) + assert(not p.is_dcp()) -p.solve(qcp=True) + p.solve(qcp=True) -print(p.value) +def dqcp2(): + a = .05 + b = .65 + + x = cp.Variable(pos=True) + + p = cp.Problem( + cp.Minimize(cp.sqrt(cp.inv_pos(cp.square(x)) - 1)), [ + a * cp.inv_pos(x) - (1 - b) * cp.sqrt(1 - cp.square(x)) <= 0 + ]) + + assert(p.is_dqcp()) + assert(not p.is_dgp()) + assert(not p.is_dcp()) + + p.solve(qcp=True) + +def dqcp3(): + x = cp.Variable() + y = cp.Variable(nonneg=True) + p = cp.Problem( + cp.Minimize(0), [ + x / y <= 3, + x == 12, + y <= 6 + ]) + + assert(p.is_dqcp()) + assert(not p.is_dgp()) + assert(not p.is_dcp()) + + p.solve(qcp=True) + +def dqcp4(): + x = cp.Variable() + p = cp.Problem( + cp.Minimize(-x), [ + cp.dist_ratio(cp.vstack([x, 1.0]), np.array([-1, -1]), np.array([0.0, -10.0])) <= 1, + 10 <= x + ]) + + assert(p.is_dqcp()) + assert(not p.is_dgp()) + assert(not p.is_dcp()) + + p.solve(qcp=True) + +if __name__ == "__main__": + dqcp1() + dqcp2() + dqcp3() + dqcp4() diff --git a/CvxLean/Test/Solve/RealToFloat.lean b/CvxLean/Test/Solve/RealToFloat.lean index b51c0145..8b7ad384 100644 --- a/CvxLean/Test/Solve/RealToFloat.lean +++ b/CvxLean/Test/Solve/RealToFloat.lean @@ -6,33 +6,34 @@ section RealToFloat -- Simple tests. -#realToFloat Real +#real_to_float Real -#realToFloat (8 : Real) +#real_to_float (8 : Real) -#realToFloat (0 : Real) + 1 +#real_to_float (0 : Real) + 1 -#realToFloat Real.exp 1 +#real_to_float Real.exp 1 -#realToFloat fun (p : Real × (Fin 2 → Real)) => Real.exp p.1 + p.2 1 ≤ 0 ∧ Real.exp p.1 + p.2 1 ≤ 0 +#real_to_float fun (p : Real × (Fin 2 → Real)) => + Real.exp p.1 + p.2 1 ≤ 0 ∧ Real.exp p.1 + p.2 1 ≤ 0 -#realToFloat (1 : Real) = 3 +#real_to_float (1 : Real) = 3 -#realToFloat (2 • (1 : Matrix (Fin 1) (Fin 1) Real)) +#real_to_float (2 • (1 : Matrix (Fin 1) (Fin 1) Real)) -- Convert whole minimization problem. -#realToFloat @Minimization.mk Real Real (fun (x : Real) => x) (fun (x : Real) => x <= 0) +#real_to_float @Minimization.mk Real Real (fun (x : Real) => x) (fun (x : Real) => x <= 0) -- Test optimization_param. @[optimization_param] noncomputable def A : ℝ := 1 -#realToFloat A +#real_to_float A noncomputable def B : ℝ := 1 -#realToFloat B +#real_to_float B end RealToFloat diff --git a/egg-pre-dcp/Cargo.lock b/egg-pre-dcp/Cargo.lock index 442a0235..5580cad6 100644 --- a/egg-pre-dcp/Cargo.lock +++ b/egg-pre-dcp/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.7.7" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ "getrandom", "once_cell", @@ -99,9 +99,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "gmp-mpfr-sys" -version = "1.6.1" +version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19c5c67d8c29fe87e3266e691dd60948e6e4df4496c53355ef3551142945721b" +checksum = "362a6cc3cbe9f41aebe49c03b91aee8fa8fc69d32fb90533f6ed965a882e08e3" dependencies = [ "libc", "windows-sys", @@ -161,15 +161,21 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "libc" -version = "0.2.149" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "log" @@ -179,18 +185,18 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", ] [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "ordered-float" @@ -209,18 +215,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -257,35 +263,36 @@ dependencies = [ [[package]] name = "rug" -version = "1.22.0" +version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a118f8296241f6952511e8f9ab5b8a759c0d9ed902159b0aeed82d902e84ca6" +checksum = "76a82fd85950d103ad075f104d10c77d71640830c6a959a418380be380eaf7cd" dependencies = [ "az", "gmp-mpfr-sys", "libc", + "libm", ] [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", @@ -294,9 +301,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.108" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "itoa", "ryu", @@ -305,9 +312,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "symbol_table" @@ -328,9 +335,9 @@ checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71" [[package]] name = "syn" -version = "2.0.38" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -339,18 +346,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.50" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.50" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", @@ -386,9 +393,18 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "windows-sys" -version = "0.42.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -401,42 +417,42 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.42.2" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" diff --git a/egg-pre-dcp/src/cost.rs b/egg-pre-dcp/src/cost.rs index 59189857..0a5426f1 100644 --- a/egg-pre-dcp/src/cost.rs +++ b/egg-pre-dcp/src/cost.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use egg::{*}; use crate::domain; @@ -13,23 +15,51 @@ pub struct DCPCost<'a> { pub egraph: &'a optimization::EGraph, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cost { + pub curvature: Curvature, + pub num_vars: u32, + pub term_size: u32 +} + +impl PartialOrd for Cost { + fn partial_cmp(&self, other:&Cost) -> Option { + match self.curvature.partial_cmp(&other.curvature) { + Some(Ordering::Equal) => { + match self.num_vars.partial_cmp(&other.num_vars) { + Some(Ordering::Equal) => { + self.term_size.partial_cmp(&other.term_size) + } + x => x + } + } + // It may happen that, on a pass, two nodes in the same e-class are tagged with convex + // and concave curvatures. It must be the case that on a later pass, an affine node is + // found and the whole class is tagged as affine. However, we need to temporarily pick + // one. We tell `egg` to pick either by saying that both curvatures give the same cost. + None => Some(Ordering::Equal), + x => x + } + } +} + impl<'a> CostFunction for DCPCost<'a> { // Curvature + number of variables (with repetition) + term size. // In lexicographic order. - type Cost = (Curvature, u32, u32); + type Cost = Cost; // Curvature analysis corresponds exactly to the DCP rules. fn cost(&mut self, enode: &Optimization, mut costs: C) -> Self::Cost where C: FnMut(Id) -> Self::Cost { macro_rules! get_curvature { - ($i:expr) => { costs(*$i).0 } + ($i:expr) => { costs(*$i).curvature } } macro_rules! get_num_vars { - ($i:expr) => { costs(*$i).1 } + ($i:expr) => { costs(*$i).num_vars } } macro_rules! get_term_size { - ($i:expr) => { costs(*$i).2 } + ($i:expr) => { costs(*$i).term_size } } let get_domain = @@ -338,7 +368,11 @@ impl<'a> CostFunction for DCPCost<'a> { } } - return (curvature, num_vars, term_size); + return Cost { + curvature: curvature, + num_vars: num_vars, + term_size: term_size + }; } } diff --git a/egg-pre-dcp/src/domain.rs b/egg-pre-dcp/src/domain.rs index d085b72f..7d752a94 100644 --- a/egg-pre-dcp/src/domain.rs +++ b/egg-pre-dcp/src/domain.rs @@ -8,17 +8,19 @@ use rug::{Float, float::Round, ops::DivAssignRound, ops::PowAssignRound}; const F64_PREC: u32 = 53; -pub fn zero() -> Float { Float::with_val(F64_PREC, 0.0) } +pub fn make_float(f: f64) -> Float { Float::with_val(F64_PREC, f) } -pub fn neg_zero() -> Float { Float::with_val(F64_PREC, -0.0) } +pub fn zero() -> Float { make_float(0.0) } -pub fn one() -> Float { Float::with_val(F64_PREC, 1.0) } +pub fn neg_zero() -> Float { make_float(-0.0) } -pub fn neg_one() -> Float { Float::with_val(F64_PREC, -1.0) } +pub fn one() -> Float { make_float(1.0) } -pub fn inf() -> Float { Float::with_val(F64_PREC, f64::INFINITY) } +pub fn neg_one() -> Float { make_float(-1.0) } -pub fn neg_inf() -> Float { Float::with_val(F64_PREC, f64::NEG_INFINITY) } +pub fn inf() -> Float { make_float(f64::INFINITY) } + +pub fn neg_inf() -> Float { make_float(f64::NEG_INFINITY) } const NO_ERROR: ErrorInterval = ErrorInterval { lo: false, hi: false }; @@ -64,7 +66,7 @@ impl Domain { } pub fn make_singleton(f: f64) -> Domain { - let f_f = Float::with_val(F64_PREC, f); + let f_f = make_float(f); Domain::make_from_endpoints( f_f.clone(), f_f.clone(), false, false ) @@ -266,7 +268,7 @@ fn custom_string_to_float(s: String) -> Option { "-inf" => Some(neg_inf()), _ => { match s.parse::() { - Ok(f) => Some(Float::with_val(F64_PREC, f)), + Ok(f) => Some(make_float(f)), _ => None } } @@ -851,7 +853,7 @@ pub fn pow(d1: &Domain, d2: &Domain) -> Domain { let d = abs(d1); let lo = d.lo_float(); let hi = d.hi_float(); - let c = &Float::with_val(F64_PREC, f); + let c = &make_float(f); let l = d.lo_open; let r = d.hi_open; let res = diff --git a/egg-pre-dcp/src/explain_util.rs b/egg-pre-dcp/src/explain_util.rs index 9a260658..0517d05a 100644 --- a/egg-pre-dcp/src/explain_util.rs +++ b/egg-pre-dcp/src/explain_util.rs @@ -38,6 +38,21 @@ pub fn make_rule_table(rules: &Vec) -> HashMap { table } +// Remove rewrite labels and put a "?" in the subexpression to be rewritten. +pub fn expected_expr_with_hole(ft : &FlatTerm) -> FlatTerm { + if ft.forward_rule.is_some() || ft.backward_rule.is_some() { + FlatTerm::new(Optimization::Symbol("?".into()), vec![]) + } else { + FlatTerm::new( + ft.node.clone(), + ft.children + .iter() + .map(|child| expected_expr_with_hole(child)) + .collect() + ) + } +} + // These are functions taken from `explain.rs` removing the assertions. #[allow(unused)] diff --git a/egg-pre-dcp/src/extract.rs b/egg-pre-dcp/src/extract.rs index 1641b35c..b467f04f 100644 --- a/egg-pre-dcp/src/extract.rs +++ b/egg-pre-dcp/src/extract.rs @@ -3,6 +3,8 @@ use std::convert::TryInto; use std::fs; use std::time::Duration; use std::collections::HashMap; +use std::vec; +use std::time::Instant; use serde::{Deserialize, Serialize}; use crate::domain; @@ -18,12 +20,15 @@ use optimization::Meta as Meta; use crate::rules; use rules::rules as rules; +#[cfg(not(stop_on_success))] use crate::cost; +#[cfg(not(stop_on_success))] use cost::DCPCost as DCPCost; use crate::explain_util; use explain_util::Direction as Direction; use explain_util::get_rewrite_name_and_direction as get_rewrite_name_and_direction; +use explain_util::expected_expr_with_hole as expected_expr_with_hole; #[derive(Serialize, Debug)] pub struct Step { @@ -48,28 +53,29 @@ fn get_step_aux( Optimization::ObjFun(_) => { *location = Some("objFun".to_string()); - let o_s = next.children[0].get_recexpr().to_string(); + let o_s = expected_expr_with_hole(&next.children[0]).get_string(); *expected_term = Some(o_s); }, Optimization::Constr([_, _]) => { - let h_s = next.children[0].get_recexpr().to_string(); + let h_s = next.children[0].get_string(); *location = Some(h_s); - let c_s = next.children[1].get_recexpr().to_string(); + let c_s = expected_expr_with_hole(&next.children[1]).get_string(); *expected_term = Some(c_s); } _ => { if expected_term.is_none() { // Out-of-context extraction. Useful for testing. - let c_s = next.get_recexpr().to_string(); + let c_s = expected_expr_with_hole(&next).get_string(); *expected_term = Some(c_s); } } } if let Some(rule_name) = &next.backward_rule { - let subexpr_from = current.get_recexpr().to_string(); - let subexpr_to = next.get_recexpr().to_string(); + let subexpr_from = current.remove_rewrites().get_string(); + let subexpr_to = next.remove_rewrites().get_string(); + if expected_term.is_some() { return Some(Step { rewrite_name: rule_name.to_string(), @@ -85,8 +91,8 @@ fn get_step_aux( } if let Some(rule_name) = &next.forward_rule { - let subexpr_from = current.get_recexpr().to_string(); - let subexpr_to = next.get_recexpr().to_string(); + let subexpr_from = current.remove_rewrites().get_string(); + let subexpr_to = next.remove_rewrites().get_string(); if expected_term.is_some() { return Some(Step { rewrite_name: rule_name.to_string(), @@ -144,6 +150,11 @@ impl ToString for Minimization { // Return the rewrite steps if egg successfully found a chain of rewrites to // transform the term into DCP form. Return `None` if it didn't. +#[allow(unused)] +pub fn get_steps_maybe_node_limit(prob: Minimization, domains_vec: Vec<(String, Domain)>, debug: bool, node_limit: Option) -> Option> { + get_steps_from_string_maybe_node_limit(&prob.to_string(), domains_vec, debug, node_limit) +} + pub fn get_steps(prob: Minimization, domains_vec: Vec<(String, Domain)>, debug: bool) -> Option> { get_steps_from_string(&prob.to_string(), domains_vec, debug) } @@ -157,9 +168,11 @@ pub fn get_steps_from_string_maybe_node_limit( domains_vec: Vec<(String, Domain)>, debug: bool, node_limit: Option) -> Option> { + let starting_time: Instant = Instant::now(); let expr: RecExpr = prob_s.parse().unwrap(); // Process domains, intersecting domains assigned to the same variable. + let domains_len = domains_vec.len(); let mut domains: HashMap = HashMap::new(); for (x, dom) in domains_vec { if domains.contains_key(&x) { @@ -180,29 +193,53 @@ pub fn get_steps_from_string_maybe_node_limit( } } - // Specify a node limit, or try with default ones. + // Choose the specified node limit, or select the default ones (for `stop_on_success`, we set + // a large limit; note that if a DCP term is found it will not be reached). let node_limits = if let Some(n) = node_limit { vec![n] } else { - vec![2500, 5000, 10000, 20000, 50000, 80000] + if cfg!(stop_on_success) { + vec![100000] + } else { + vec![2500, 5000, 10000, 20000, 40000, 80000] + } }; + for node_limit in node_limits { let analysis = Meta { domains : domains.clone() }; + // Set up the runner with the given expression, analysis and limits. let iter_limit = node_limit / 250; let time_limit = (node_limit / 500).try_into().unwrap(); - let runner: Runner = + let mut runner: Runner = Runner::new(analysis) .with_explanations_enabled() .with_explanation_length_optimization() .with_node_limit(node_limit) .with_iter_limit(iter_limit) .with_time_limit(Duration::from_secs(time_limit)) - .with_expr(&expr) - .run(&rules()); + .with_expr(&expr); + + #[cfg(stop_on_success)] + { + runner = + runner + .with_hook(|runner| { + let data = runner.egraph[runner.roots[0]].data.clone(); + if data.curvature <= Curvature::Convex { + return Err("DCP term found.".to_string()); + } + return Ok(()); + }) + .run(&rules()) + } + #[cfg(not(stop_on_success))] + { + runner = runner.run(&rules()) + } if debug { println!("Creating graph with {:?} nodes.", runner.egraph.total_number_of_nodes()); @@ -212,46 +249,83 @@ pub fn get_steps_from_string_maybe_node_limit( let root = runner.roots[0]; - let best_cost; + // Extract the best term and best cost. This is obtained directly from the e-class + // analysis in the `stop_on_success` case, and by running the extractor otherwise. + let best_curvature; + let best_num_vars; + let best_term_size; let best; + #[cfg(stop_on_success)] + { + let result_data = runner.egraph[root].data.clone(); + best = result_data.best; + best_curvature = result_data.curvature; + best_num_vars = result_data.num_vars; + best_term_size = result_data.term_size; + } + #[cfg(not(stop_on_success))] { let cost_func = DCPCost { egraph: &runner.egraph }; - let extractor = - Extractor::new(&runner.egraph, cost_func); - let (best_cost_found, best_found) = - extractor.find_best(root); + let extractor = Extractor::new(&runner.egraph, cost_func); + let (best_cost_found, best_found) = extractor.find_best(root); best = best_found; - best_cost = best_cost_found; + best_curvature = best_cost_found.curvature; + best_num_vars = best_cost_found.num_vars; + best_term_size = best_cost_found.term_size; } - if debug { - println!("Best cost: {:?}", best_cost); - println!("Best: {:?}", best.to_string()); + + let curvature = best_curvature; + let num_vars = best_num_vars; + // Note: each domain constraint is an expression with 3 nodes, e.g. `0 <= x`. + let term_size = best_term_size + 3 * (domains_len as u32); + if debug && curvature <= Curvature::Convex { + let total_nodes = runner.egraph.total_number_of_nodes(); + println!("Succeeded with node limit {:?} (using {:?} nodes).", node_limit, total_nodes); + println!("Best curvature: {:?}.", curvature); + println!("Best number of variables: {:?}.", num_vars); + println!("Best term size: {:?}.", term_size); + println!("Best term: {:?}.", best.to_string()); } + + // If term is not DCP, try with the next node limit. + if curvature <= Curvature::Convex { + if debug { + let build_time = starting_time.elapsed().as_millis(); + println!("E-graph building time: {:.2?} ms.", build_time); + } + } else { + continue; + } + let after_build_time = Instant::now(); + // If term is DCP, extract the steps. let mut egraph = runner.egraph; let mut explanation : Explanation = egraph.explain_equivalence(&expr, &best); let flat_explanation : &FlatExplanation = explanation.make_flat_explanation(); if debug { - println!("{} steps", flat_explanation.len() - 1); + println!("Number of steps: {}.", flat_explanation.len() - 1); } let mut res = Vec::new(); - if best_cost.0 <= Curvature::Convex { - for i in 0..flat_explanation.len() - 1 { - let current = &flat_explanation[i]; - let next = &flat_explanation[i + 1]; - match get_step(current, next) { - Some(step) => { res.push(step); } - None => { - // Should not get here. - println!("Failed to extract step."); - } + for i in 0..flat_explanation.len() - 1 { + let current = &flat_explanation[i]; + let next = &flat_explanation[i + 1]; + match get_step(current, next) { + Some(step) => { res.push(step); } + None => { + // Should not get here. + println!("Failed to extract step."); } } - } else { - continue; + } + + if debug { + let extract_time = after_build_time.elapsed().as_millis(); + let total_time = starting_time.elapsed().as_millis(); + println!("Step extraction time: {:.2?} ms.", extract_time); + println!("Total time: {:.2?} ms.", total_time); } return Some(res); diff --git a/egg-pre-dcp/src/optimization.rs b/egg-pre-dcp/src/optimization.rs index 60decfc8..e717652f 100644 --- a/egg-pre-dcp/src/optimization.rs +++ b/egg-pre-dcp/src/optimization.rs @@ -5,6 +5,11 @@ use ordered_float::NotNan; use crate::domain; use domain::Domain as Domain; +#[cfg(stop_on_success)] +use crate::curvature; +#[cfg(stop_on_success)] +use curvature::Curvature as Curvature; + pub type Constant = NotNan; define_language! { @@ -49,12 +54,14 @@ pub struct Meta { pub domains : HashMap, } +#[cfg(not(stop_on_success))] #[derive(Debug, Clone)] pub struct Data { pub domain: Option, pub is_constant: bool, } +#[cfg(not(stop_on_success))] impl Analysis for Meta { type Data = Data; @@ -230,6 +237,477 @@ impl Analysis for Meta { } } +#[cfg(stop_on_success)] +#[derive(Debug, Clone)] +pub struct DataWithCost { + pub domain: Option, + pub is_constant: bool, + pub curvature: Curvature, + pub best: RecExpr, + pub num_vars: u32, + pub term_size: u32, +} + +#[cfg(stop_on_success)] +impl Analysis for Meta { + type Data = DataWithCost; + + fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + let d_before_o = to.domain.clone(); + match (to.domain.clone(), from.domain.clone()) { + (None, Some(_)) => { to.domain = from.domain.clone(); } + (Some(d_to), Some(d_from)) => { + if !d_to.eq(&d_from) { + let inter = d_to.intersection(&d_from); + if Domain::is_empty(&inter) { + // Should never get here for feasible problems. + to.domain = None + } else { + to.domain = Some(inter); + } + } + } + _ => () + } + let to_domain_diff = + match (d_before_o, to.domain.clone()) { + (Some(d_before), Some(d_to)) => !d_before.eq(&d_to), + (None, None) => false, + _ => true + }; + let from_domain_diff = + match (to.domain.clone(), from.domain.clone()) { + (Some(d_to), Some(d_from)) => !d_to.eq(&d_from), + (None, None) => false, + _ => true + }; + + let is_constant_before = to.is_constant.clone(); + to.is_constant = to.is_constant || from.is_constant; + let to_is_constant_diff = is_constant_before != to.is_constant; + let from_is_constant_diff = to.is_constant != from.is_constant; + + let curvature_before = to.curvature.clone(); + let best_before = to.best.clone(); + let num_vars_before = to.num_vars.clone(); + let term_size_before = to.term_size.clone(); + if from.curvature < to.curvature { + to.curvature = from.curvature.clone(); + to.best = from.best.clone(); + to.num_vars = from.num_vars.clone(); + to.term_size = from.term_size.clone(); + } else if to.curvature == from.curvature { + // Select smallest "best". + if (from.num_vars, from.term_size) < (to.num_vars, to.term_size) { + to.best = from.best.clone(); + to.num_vars = from.num_vars.clone(); + to.term_size = from.term_size.clone(); + } + } + let to_curvature_diff = curvature_before != to.curvature; + let to_best_diff = best_before != to.best; + let to_num_vars_diff = num_vars_before != to.num_vars; + let to_term_size_diff = term_size_before != to.term_size; + let from_curvature_diff = to.curvature != from.curvature; + let from_best_diff = to.best != from.best; + let from_num_vars_diff = to.num_vars != from.num_vars; + let from_term_size_diff = to.term_size != from.term_size; + + DidMerge( + to_domain_diff || to_is_constant_diff || to_curvature_diff || to_best_diff || + to_num_vars_diff || to_term_size_diff, + from_domain_diff || from_is_constant_diff || from_curvature_diff || from_best_diff || + from_num_vars_diff || from_term_size_diff) + } + + fn make(egraph: &EGraph, enode: &Optimization) -> Self::Data { + let get_domain = + |i: &Id| egraph[*i].data.domain.clone(); + let get_is_constant = + |i: &Id| egraph[*i].data.is_constant.clone(); + let get_curvature = + |i: &Id| egraph[*i].data.curvature.clone(); + let get_best = + |i: &Id| egraph[*i].data.best.clone(); + let get_num_vars = + |i: &Id| egraph[*i].data.num_vars.clone(); + let get_term_size = + |i: &Id| egraph[*i].data.term_size.clone(); + + let domains_map = + egraph.analysis.domains.clone(); + + let mut domain = None; + let mut is_constant = false; + let mut curvature = Curvature::Unknown; + let mut best = RecExpr::default(); + let mut num_vars = 0; + let mut term_size = 0; + + match enode { + Optimization::Prob([a, b]) => { + curvature = + if get_curvature(a) >= get_curvature(b) { + get_curvature(a) + } else if get_curvature(b) >= get_curvature(a) { + get_curvature(b) + } else { + // Should not get here. + Curvature::Unknown + }; + best = format!("(prob {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::ObjFun(a) => { + // It cannot be concave, because of mapping functions. + curvature = + if get_curvature(a) <= Curvature::Convex { + get_curvature(a) + } else { + Curvature::Unknown + }; + best = format!("(objFun {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Constr([h, c]) => { + // It cannot be concave, because the notion of concavity at the Prop (or set) level + // is not well-defined. + curvature = + if get_curvature(c) <= Curvature::Convex { + get_curvature(c) + } else { + Curvature::Unknown + }; + best = format!("(constr {} {})", get_best(h), get_best(c)).parse().unwrap(); + num_vars = get_num_vars(c); + term_size = 1 + get_term_size(c); + } + Optimization::Constrs(a) => { + curvature = Curvature::Constant; + term_size = 0; + num_vars = 0; + for c in a.iter() { + if curvature < get_curvature(c) { + curvature = get_curvature(c); + } + num_vars += get_num_vars(c); + term_size += get_term_size(c); + } + let constrs_s_l : Vec = + a.iter().map(|c| format!("{}", get_best(c))).collect(); + best = format!("(constrs {})", constrs_s_l.join(" ")).parse().unwrap(); + } + Optimization::Eq([a, b]) => { + if get_curvature(a) <= Curvature::Affine && get_curvature(b) <= Curvature::Affine { + curvature = Curvature::Affine + } + best = format!("(eq {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Le([a, b]) => { + curvature = curvature::of_le(get_curvature(a), get_curvature(b)); + best = format!("(le {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Neg(a) => { + domain = domain::option_neg(get_domain(a)); + is_constant = get_is_constant(a); + curvature = curvature::of_neg(get_curvature(a)); + best = format!("(neg {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Inv(a) => { + domain = domain::option_inv(get_domain(a)); + is_constant = get_is_constant(a); + let da_pos = domain::option_is_pos(get_domain(a).as_ref()); + if da_pos { + curvature = curvature::of_convex_nonincreasing_fn(get_curvature(a)); + } + best = format!("(inv {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Abs(a) => { + domain = domain::option_abs(get_domain(a)); + is_constant = get_is_constant(a); + let da_nonneg = domain::option_is_nonneg(get_domain(a).as_ref()); + let da_nonpos = domain::option_is_nonpos(get_domain(a).as_ref()); + if da_nonneg { + curvature = curvature::of_convex_nondecreasing_fn(get_curvature(a)); + } else if da_nonpos { + curvature = curvature::of_convex_nonincreasing_fn(get_curvature(a)); + } else { + curvature = curvature::of_convex_none_fn(get_curvature(a)); + } + best = format!("(abs {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Sqrt(a) => { + domain = domain::option_sqrt(get_domain(a)); + is_constant = get_is_constant(a); + let da_nonneg = domain::option_is_nonneg(get_domain(a).as_ref()); + if da_nonneg { + curvature = curvature::of_concave_nondecreasing_fn(get_curvature(a)); + } + best = format!("(sqrt {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Log(a) => { + domain = domain::option_log(get_domain(a)); + is_constant = get_is_constant(a); + let da_nonneg = domain::option_is_nonneg(get_domain(a).as_ref()); + if da_nonneg { + curvature = curvature::of_concave_nondecreasing_fn(get_curvature(a)); + } + best = format!("(log {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Exp(a) => { + domain = domain::option_exp(get_domain(a)); + is_constant = get_is_constant(a); + curvature = curvature::of_convex_nondecreasing_fn(get_curvature(a)); + best = format!("(exp {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::XExp(a) => { + domain = domain::option_xexp(get_domain(a)); + is_constant = get_is_constant(a); + let da_nonneg = domain::option_is_nonneg(get_domain(a).as_ref()); + if da_nonneg { + curvature = curvature::of_convex_nondecreasing_fn(get_curvature(a)); + } + best = format!("(xexp {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Entr(a) => { + domain = domain::option_entr(get_domain(a)); + is_constant = get_is_constant(a); + let da_pos = domain::option_is_pos(get_domain(a).as_ref()); + if da_pos { + curvature = curvature::of_concave_none_fn(get_curvature(a)); + } + best = format!("(entr {})", get_best(a)).parse().unwrap(); + num_vars = get_num_vars(a); + term_size = 1 + get_term_size(a); + } + Optimization::Min([a, b]) => { + domain = domain::option_min(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let curvature_a = curvature::of_concave_nondecreasing_fn(get_curvature(a)); + let curvature_b = curvature::of_concave_nondecreasing_fn(get_curvature(b)); + curvature = curvature::join(curvature_a, curvature_b); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Max([a, b]) => { + domain = domain::option_max(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let curvature_a = curvature::of_convex_nondecreasing_fn(get_curvature(a)); + let curvature_b = curvature::of_convex_nondecreasing_fn(get_curvature(b)); + curvature = curvature::join(curvature_a, curvature_b); + best = format!("(max {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Add([a, b]) => { + domain = domain::option_add(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + curvature = curvature::of_add(get_curvature(a), get_curvature(b)); + best = format!("(add {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Sub([a, b]) => { + domain = domain::option_sub(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + curvature = curvature::of_sub(get_curvature(a), get_curvature(b)); + best = format!("(sub {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Mul([a, b]) => { + domain = domain::option_mul(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let da_o = get_domain(a); + let db_o = get_domain(b); + match (get_is_constant(a), get_is_constant(b)) { + (true, true) => { + curvature = Curvature::Constant + } + (true, false) => { + if let Some(da) = da_o { + curvature = curvature::of_mul_by_const(get_curvature(b), da) + } + } + (false, true) => { + if let Some(db) = db_o { + curvature = curvature::of_mul_by_const(get_curvature(a), db) + } + } + _ => { } + } + best = format!("(mul {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Div([a, b]) => { + domain = domain::option_div(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let db_o = get_domain(b); + match (get_is_constant(a), get_is_constant(b)) { + (true, true) => { + if let Some(db) = db_o { + if domain::does_not_contain_zero(&db) { + curvature = Curvature::Constant + } + } + } + (false, true) => { + if let Some(db) = db_o { + if domain::does_not_contain_zero(&db) { + curvature = curvature::of_mul_by_const(get_curvature(a), db) + } + } + } + _ => { } + }; + best = format!("(div {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Pow([a, b]) => { + domain = domain::option_pow(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + if get_is_constant(b) { + if let Some(db) = get_domain(b) { + // Domain guards already in `of_pow_by_const`. + curvature = curvature::of_pow_by_const(get_curvature(a), db, get_domain(a)) + } + } + best = format!("(pow {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::QOL([a, b]) => { + domain = domain::option_quad_over_lin(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let curvature_num = curvature::of_convex_none_fn(get_curvature(a)); + let db_pos = domain::option_is_pos(get_domain(b).as_ref()); + if db_pos { + let curvature_den = curvature::of_convex_nonincreasing_fn(get_curvature(b)); + curvature = curvature::join(curvature_num, curvature_den); + } + best = format!("(qol {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Geo([a, b]) => { + domain = domain::option_geo_mean(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let da_nonneg = domain::option_is_nonneg(get_domain(a).as_ref()); + let db_nonneg = domain::option_is_nonneg(get_domain(b).as_ref()); + if da_nonneg && db_nonneg { + let curvature_a = curvature::of_concave_nondecreasing_fn(get_curvature(a)); + let curvature_b = curvature::of_concave_nondecreasing_fn(get_curvature(b)); + curvature = curvature::join(curvature_a, curvature_b); + } + best = format!("(geo {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::LSE([a, b]) => { + domain = domain::option_log_sum_exp(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let curvature_a = curvature::of_convex_nondecreasing_fn(get_curvature(a)); + let curvature_b = curvature::of_convex_nondecreasing_fn(get_curvature(b)); + curvature = curvature::join(curvature_a, curvature_b); + best = format!("(lse {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Norm2([a, b]) => { + domain = domain::option_norm2(get_domain(a), get_domain(b)); + is_constant = get_is_constant(a) && get_is_constant(b); + let curvature_a = curvature::of_convex_none_fn(get_curvature(a)); + let curvature_b = curvature::of_convex_none_fn(get_curvature(b)); + curvature = curvature::join(curvature_a, curvature_b); + best = format!("(norm2 {} {})", get_best(a), get_best(b)).parse().unwrap(); + num_vars = get_num_vars(a) + get_num_vars(b); + term_size = 1 + get_term_size(a) + get_term_size(b); + } + Optimization::Var(a) => { + // Assume that after var there is always a symbol. + match egraph[*a].nodes[0] { + Optimization::Symbol(s) => { + let s_s = format!("{}", s); + match domains_map.get(&s_s) { + Some(d) => { domain = Some(d.clone()); } + _ => { domain = Some(domain::free_dom()); } + } + curvature = Curvature::Affine; + best = format!("(var {})", s).parse().unwrap(); + num_vars = 1; + term_size = 1; + } + _ => {} + } + } + Optimization::Param(a) => { + match egraph[*a].nodes[0] { + Optimization::Symbol(s) => { + let s_s = format!("{}", s); + match domains_map.get(&s_s) { + Some(d) => { domain = Some(d.clone()); } + _ => { domain = Some(domain::free_dom()); } + } + curvature = Curvature::Constant; + best = format!("(param {})", s).parse().unwrap(); + num_vars = 0; + term_size = 1; + } + _ => {} + } + // NOTE(RFM): parameters are treated as constants. + is_constant = true; + + } + Optimization::Pi(_) => { + domain = Some(Domain::make_singleton(std::f64::consts::PI)); + is_constant = true; + curvature = Curvature::Constant; + best = format!("pi").parse().unwrap(); + num_vars = 0; + term_size = 1; + } + Optimization::Symbol(s) => { + best = format!("{}", s).parse().unwrap(); + num_vars = 0; + term_size = 0; + } + Optimization::Constant(f) => { + domain = Some(Domain::make_singleton((*f).into_inner())); + is_constant = true; + curvature = Curvature::Constant; + best = format!("{}", f).parse().unwrap(); + num_vars = 0; + term_size = 1; + } + } + + DataWithCost { domain, is_constant, curvature, best, num_vars, term_size } + } +} + pub fn is_gt_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); move |egraph, _, subst| { @@ -292,3 +770,20 @@ pub fn is_nat(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { return false; } } + +pub fn is_le(var1: &str, var2: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var1 = var1.parse().unwrap(); + let var2 = var2.parse().unwrap(); + move |egraph, _, subst| { + if let Some(d1) = &egraph[subst[var1]].data.domain { + if let Some(d2) = &egraph[subst[var2]].data.domain { + return Domain::subseteq(d1, d2); + } + } + return false; + } +} + +pub fn is_ge(var1: &str, var2: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + is_le(var2, var1) +} diff --git a/egg-pre-dcp/src/rules.rs b/egg-pre-dcp/src/rules.rs index 6fdf2cfc..680c0b6f 100644 --- a/egg-pre-dcp/src/rules.rs +++ b/egg-pre-dcp/src/rules.rs @@ -8,6 +8,8 @@ use optimization::is_ge_zero as is_ge_zero; use optimization::is_le_zero as is_le_zero; use optimization::is_not_zero as is_not_zero; use optimization::is_nat as is_nat; +use optimization::is_le as is_le; +use optimization::is_ge as is_ge; pub fn rules() -> Vec> { vec![ @@ -27,6 +29,9 @@ pub fn rules() -> Vec> { vec![ rw!("log_eq_log"; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" if is_gt_zero("?a") if is_gt_zero("?b")), + + rw!("log_eq_log-rev"; "(eq (log ?a) (log ?b))" => "(eq ?a ?b)" + if is_gt_zero("?a") if is_gt_zero("?b")), /* Less than or equal rules. */ @@ -176,6 +181,9 @@ pub fn rules() -> Vec> { vec![ rw!("one_div_eq_pow_neg_one"; "(div 1 ?a)" => "(pow ?a (neg 1))" if is_gt_zero("?a")), + + rw!("one_div_eq_pow_neg_one-rev"; "(pow ?a (neg 1))" => "(div 1 ?a)" + if is_gt_zero("?a")), rw!("sqrt_eq_rpow"; "(sqrt ?a)" => "(pow ?a 0.5)"), @@ -197,9 +205,6 @@ pub fn rules() -> Vec> { vec![ rw!("binomial_two"; "(pow (add ?a ?b) 2)" => "(add (pow ?a 2) (add (mul 2 (mul ?a ?b)) (pow ?b 2)))"), - // rw!("rpow_eq_mul_rpow_pred"; "(pow ?a ?b)" => "(mul ?a (pow ?a (sub ?b 1)))" - // if is_not_zero("?a")), - rw!("inv_eq_pow_neg_one"; "(inv ?a)" => "(pow ?a (neg 1))" if is_not_zero("?a")), @@ -256,6 +261,17 @@ pub fn rules() -> Vec> { vec![ rw!("abs_nonpos"; "(abs ?a)" => "(neg ?a)" if is_le_zero("?a")), + + /* Min and max rules. */ + + rw!("min_eq_left"; "(min ?a ?b)" => "?a" if is_le("?a", "?b")), + + rw!("min_eq_right"; "(min ?a ?b)" => "?b" if is_le("?b", "?a")), + + rw!("max_eq_left"; "(max ?a ?b)" => "?a" if is_ge("?a", "?b")), + + rw!("max_eq_right"; "(max ?a ?b)" => "?b" if is_ge("?b", "?a")), + /* Atom folding and unfolding rules. */ diff --git a/egg-pre-dcp/src/test_util.rs b/egg-pre-dcp/src/test_util.rs index 2559151d..1e73c060 100644 --- a/egg-pre-dcp/src/test_util.rs +++ b/egg-pre-dcp/src/test_util.rs @@ -4,6 +4,7 @@ use domain::Domain as Domain; use crate::extract; use extract::Minimization as Minimization; use extract::get_steps as get_steps; +use extract::get_steps_maybe_node_limit as get_steps_maybe_node_limit; use extract::get_steps_from_string_maybe_node_limit as get_steps_from_string_maybe_node_limit; fn make(obj: &str, constrs: Vec<&str>) -> Minimization { @@ -18,11 +19,20 @@ fn make(obj: &str, constrs: Vec<&str>) -> Minimization { }; } -fn convexify_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>, print: bool) { +fn pre_dcp_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>, print: bool) { let prob = make(obj, constrs); let domains = domains.iter().map(|(s, d)| ((*s).to_string(), d.clone())).collect(); - let steps = get_steps(prob, domains, true); + let steps = + match std::env::var("EGG_PRE_DCP_NODE_LIMIT") { + Result::Ok(v) => { + let node_limit = v.parse::().unwrap(); + get_steps_maybe_node_limit(prob, domains, print, Some(node_limit)) + } + Result::Err(_) => { + get_steps(prob, domains, print) + } + }; if steps.is_none() { panic!("Test failed, could not rewrite target into DCP form."); } @@ -31,28 +41,40 @@ fn convexify_check_with_domain_maybe_print(domains : Vec<(&str, Domain)>, obj: & } } -pub fn convexify_check_with_domain(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { - convexify_check_with_domain_maybe_print(domains, obj, constrs, false) +pub fn pre_dcp_check_with_domain(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(domains, obj, constrs, false) } -pub fn convexify_check_with_domain_and_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { - convexify_check_with_domain_maybe_print(domains, obj, constrs, true) +pub fn pre_dcp_check_with_domain_and_print(domains : Vec<(&str, Domain)>, obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(domains, obj, constrs, true) } -pub fn convexify_check(obj: &str, constrs: Vec<&str>) { - convexify_check_with_domain_maybe_print(vec![], obj, constrs, false) +pub fn pre_dcp_check(obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(vec![], obj, constrs, false) } -pub fn convexify_check_and_print(obj: &str, constrs: Vec<&str>) { - convexify_check_with_domain_maybe_print(vec![], obj, constrs, true) +pub fn pre_dcp_check_and_print(obj: &str, constrs: Vec<&str>) { + pre_dcp_check_with_domain_maybe_print(vec![], obj, constrs, true) } // Used to test out-of-context expressions. -fn convexify_check_expression_with_domain_maybe_print_maybe_node_limit( +fn pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit( domains : Vec<(&str, Domain)>, s: &str, print: bool, node_limit: Option) { + // NOTE: Even if the node limit is passed as an argument, the environment variable is used if it + // is set, hence overriding the argument. + let node_limit = + match std::env::var("EGG_PRE_DCP_NODE_LIMIT") { + Result::Ok(v) => { + let node_limit = v.parse::().unwrap(); + Some(node_limit) + } + Result::Err(_) => { + node_limit + } + }; let domains = domains.iter().map(|(s, d)| ((*s).to_string(), d.clone())).collect(); let steps = get_steps_from_string_maybe_node_limit(s, domains, true, node_limit); @@ -64,34 +86,34 @@ fn convexify_check_expression_with_domain_maybe_print_maybe_node_limit( } } -fn convexify_check_expression_with_domain_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool) { - convexify_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, None); +fn pre_dcp_check_expression_with_domain_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool) { + pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, None); } -pub fn convexify_check_expression_with_domain(domains : Vec<(&str, Domain)>,s: &str) { - convexify_check_expression_with_domain_maybe_print(domains, s, false); +pub fn pre_dcp_check_expression_with_domain(domains : Vec<(&str, Domain)>,s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(domains, s, false); } -pub fn convexify_check_expression_with_domain_and_print(domains : Vec<(&str, Domain)>,s: &str) { - convexify_check_expression_with_domain_maybe_print(domains, s, true); +pub fn pre_dcp_check_expression_with_domain_and_print(domains : Vec<(&str, Domain)>,s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(domains, s, true); } -fn convexify_check_expression_with_domain_and_node_limit_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool, node_limit: usize) { - convexify_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, Some(node_limit)); +fn pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains : Vec<(&str, Domain)>, s: &str, print: bool, node_limit: usize) { + pre_dcp_check_expression_with_domain_maybe_print_maybe_node_limit(domains, s, print, Some(node_limit)); } -pub fn convexify_check_expression_with_domain_and_node_limit(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { - convexify_check_expression_with_domain_and_node_limit_maybe_print(domains, s, false, node_limit); +pub fn pre_dcp_check_expression_with_domain_and_node_limit(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { + pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains, s, false, node_limit); } -pub fn convexify_check_expression_with_domain_and_node_limit_and_print(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { - convexify_check_expression_with_domain_and_node_limit_maybe_print(domains, s, true, node_limit); +pub fn pre_dcp_check_expression_with_domain_and_node_limit_and_print(domains : Vec<(&str, Domain)>,s: &str, node_limit: usize) { + pre_dcp_check_expression_with_domain_and_node_limit_maybe_print(domains, s, true, node_limit); } -pub fn convexify_check_expression(s: &str) { - convexify_check_expression_with_domain_maybe_print(vec![], s, false); +pub fn pre_dcp_check_expression(s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(vec![], s, false); } -pub fn convexify_check_expression_and_print(s: &str) { - convexify_check_expression_with_domain_maybe_print(vec![], s, true); +pub fn pre_dcp_check_expression_and_print(s: &str) { + pre_dcp_check_expression_with_domain_maybe_print(vec![], s, true); } diff --git a/egg-pre-dcp/tests/test_almost_dgp.rs b/egg-pre-dcp/tests/test_almost_dgp.rs index c86f6a6c..bf71a7ac 100644 --- a/egg-pre-dcp/tests/test_almost_dgp.rs +++ b/egg-pre-dcp/tests/test_almost_dgp.rs @@ -1,14 +1,49 @@ /*! -Tests from geometric programming that do not follow the DGP rules. +Tests from geometric programming that do not follow the DGP rules strictly (often because of a +negative sign that can be removed by rewritting the constraints appropriately). !*/ +mod test_almost_dgp { + use egg_pre_dcp::test_util::{*}; +#[test] +fn test_agp1() { + pre_dcp_check_and_print( + "(exp (var u))", + vec![ + "(le (sub (pow (exp (var u)) 2) (div 10123 1000)) 0)" + ]); +} + #[test] fn test_agp2() { - convexify_check( - "(exp (var x))", + pre_dcp_check_and_print( + "(exp (var u))", + vec![ + "(le (sub (mul (exp (var u)) (exp (var v))) (div 2691 500)) 0)" + ]); +} + +#[test] +fn test_agp3() { + pre_dcp_check_and_print( + "(add (add (exp (var u)) (exp (var v))) (exp (var w)))", vec![ - "(le (sub (mul (exp (var x)) (exp (var y))) (div 2691 500)) 0)" + "(le 2 (exp (var u)))", + "(le (exp (var u)) 3)", + "(le (sqrt (exp (var u))) (sub (pow (exp (var u)) 2) (div (mul 6 (exp (var v))) (exp (var w)))))", + "(eq (mul (exp (var u)) (exp (var v))) (exp (var w)))" ]); } + +#[test] +fn test_agp4() { + pre_dcp_check_and_print( + "(div 1 (mul (exp (var u)) (exp (var v))))", + vec![ + "(le (mul (exp (var u)) (exp (var v))) (sub (sub 2 (exp (var u))) (exp (var v))))" + ]); +} + +} diff --git a/egg-pre-dcp/tests/test_cost_function.rs b/egg-pre-dcp/tests/test_cost_function.rs index 348764f6..849c2db3 100644 --- a/egg-pre-dcp/tests/test_cost_function.rs +++ b/egg-pre-dcp/tests/test_cost_function.rs @@ -6,7 +6,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_cost_function_number_of_variable_occurences() { - convexify_check( + pre_dcp_check( "0", vec![ "(le (var x) (sub 1 (var x)))" @@ -15,7 +15,7 @@ fn test_cost_function_number_of_variable_occurences() { #[test] fn test_cost_function_number_of_variable_occurences_2() { - convexify_check( + pre_dcp_check( "0", vec![ "(le (add (mul 2 (var x)) (var x)) 0)" diff --git a/egg-pre-dcp/tests/test_dgp.rs b/egg-pre-dcp/tests/test_dgp.rs index 6b4a1e40..3b8cde66 100644 --- a/egg-pre-dcp/tests/test_dgp.rs +++ b/egg-pre-dcp/tests/test_dgp.rs @@ -2,54 +2,102 @@ Tests from geometric programming. !*/ -use egg_pre_dcp::domain; +mod test_dgp { use egg_pre_dcp::test_util::{*}; +#[test] +fn test_gp1() { + pre_dcp_check_and_print( + "(exp (var u))", + vec![ + "(le (pow (exp (var u)) 2) (div 10123 1000))" + ]); +} + +#[test] +fn test_gp2() { + pre_dcp_check_and_print( + "(exp (var u))", + vec![ + "(le (mul (exp (var u)) (exp (var v))) (div 2691 500))" + ]); +} + #[test] fn test_gp3() { - convexify_check( - "(exp (var x))", + pre_dcp_check_and_print( + "(exp (var u))", vec![ - "(le (sqrt (add (mul (exp (var x)) (exp (var x))) (exp (var y)))) 1)" + "(le (sqrt (add (mul (exp (var u)) (exp (var u))) (exp (var v)))) 1)" ]); } #[test] fn test_gp4() { - convexify_check( - "(div 1 (div (exp (var x)) (exp (var y))))", + pre_dcp_check_and_print( + "(div 1 (div (exp (var u)) (exp (var v))))", + vec![ + "(le 2 (exp (var u)))", + "(le (exp (var u)) 3)", + "(le (add (pow (exp (var u)) 2) (div (mul 3 (exp (var v))) (exp (var w)))) (sqrt (exp (var u))))", + "(eq (div (exp (var u)) (exp (var v))) (pow (exp (var w)) 2))" + ]); +} + +#[test] +fn test_gp5() { + pre_dcp_check_and_print( + "(div 1 (div (exp (var u)) (exp (var v))))", vec![ - "(le 2 (exp (var x)))", - "(le (exp (var x)) 3)", - "(le (add (pow (exp (var x)) 2) (div (mul 3 (exp (var y))) (exp (var z)))) (sqrt (exp (var x))))", - "(eq (div (exp (var x)) (exp (var y))) (pow (exp (var z)) 2))" + "(le 2 (exp (var u)))", + "(le (exp (var u)) 3)", + "(le (add (pow (exp (var u)) 2) (div (mul 3 (exp (var v))) (exp (var w)))) (sqrt (exp (var v))))", + "(eq (div (exp (var u)) (exp (var v))) (pow (exp (var w)) 2))" ]); } #[test] fn test_gp6() { - convexify_check_with_domain( + pre_dcp_check_and_print( + "(div 1 (div (exp (var u)) (exp (var v))))", vec![ - ("Aflr", domain::pos_dom()), - ("α" , domain::pos_dom()), - ("β" , domain::pos_dom()), - ("γ" , domain::pos_dom()), - ("δ" , domain::pos_dom())], - "(div 1 (mul (mul (exp (var h)) (exp (var w))) (exp (var d))))", + "(le 2 (exp (var u)))", + "(le (exp (var u)) 3)", + "(le (add (pow (exp (var u)) 2) (div (mul 3 (exp (var v))) (exp (var w)))) (mul 5 (sqrt (exp (var v)))))", + "(eq (mul (exp (var u)) (exp (var v))) (pow (exp (var w)) 2))" + ]); +} + +#[test] +fn test_gp7() { + pre_dcp_check_and_print( + "(div 1 (mul (mul (exp (var h')) (exp (var w'))) (exp (var d'))))", vec![ - "(le (mul 2 (add (mul (exp (var h)) (exp (var d))) (mul (exp (var w)) (exp (var d))))) (param Awall))", - "(le (mul (exp (var w)) (exp (var d))) (param Aflr))", - "(le (param α) (div (exp (var h)) (exp (var w))))", - "(le (div (exp (var h)) (exp (var w))) (param β))", - "(le (param γ) (div (exp (var d)) (exp (var w))))", - "(le (div (exp (var d)) (exp (var w))) (param δ))" + "(le (mul 2 (add (mul (exp (var h')) (exp (var d'))) (mul (exp (var w')) (exp (var d'))))) 100)", + "(le (mul (exp (var w')) (exp (var d'))) 10)", + "(le (div 1 2) (div (exp (var h')) (exp (var w'))))", + "(le (div (exp (var h')) (exp (var w'))) (div 1 2))", + "(le 5 (div (exp (var d')) (exp (var w'))))", + "(le (div (exp (var d')) (exp (var w'))) 6)" ]); } #[test] fn test_gp8() { - convexify_check( + pre_dcp_check_and_print( + "(add (add (mul (mul (div 1 (exp (var u))) (div 1 (sqrt (exp (var v))))) (div 1 (exp (var w)))) (mul (mul (div 23 10) (exp (var u))) (exp (var w)))) (mul (mul (mul 4 (exp (var u))) (exp (var v))) (exp (var w))))", + vec![ + "(le (add (mul (mul (div 1 3) (div 1 (pow (exp (var u)) 2))) (div 1 (pow (exp (var v)) 2))) (mul (mul (div 4 3) (sqrt (exp (var v)))) (div 1 (exp (var w))))) 1)", + "(le (add (add (exp (var u)) (mul 2 (exp (var v)))) (mul 3 (exp (var w)))) 1)", + "(eq (mul (mul (div 1 2) (exp (var u))) (exp (var v))) 1)" + ]); + +} + +#[test] +fn test_gp9() { + pre_dcp_check_and_print( "(mul (mul 2 (exp (var A))) (norm2 (exp (var w)) (exp (var h))))", vec![ "(le (mul (div (mul 10 (norm2 (exp (var w)) (exp (var h)))) 2) (exp (var h))) (mul (div 1 2) (exp (var A))))", @@ -62,3 +110,32 @@ fn test_gp8() { "(le (sqrt (add (div (exp (var A)) (div 314159 50000)) (pow (exp (var r)) 2))) 10)" ]); } + +} + +mod test_dgp_other { + +use egg_pre_dcp::domain; + +use egg_pre_dcp::test_util::{*}; + +#[test] +fn test_gp7_with_params() { + pre_dcp_check_with_domain( + vec![ + ("Aflr", domain::pos_dom()), + ("α" , domain::pos_dom()), + ("β" , domain::pos_dom()), + ("γ" , domain::pos_dom()), + ("δ" , domain::pos_dom())], + "(div 1 (mul (mul (exp (var h)) (exp (var w))) (exp (var d))))", + vec![ + "(le (mul 2 (add (mul (exp (var h)) (exp (var d))) (mul (exp (var w)) (exp (var d))))) (param Awall))", + "(le (mul (exp (var w)) (exp (var d))) (param Aflr))", + "(le (param α) (div (exp (var h)) (exp (var w))))", + "(le (div (exp (var h)) (exp (var w))) (param β))", + "(le (param γ) (div (exp (var d)) (exp (var w))))", + "(le (div (exp (var d)) (exp (var w))) (param δ))" + ]); +} +} diff --git a/egg-pre-dcp/tests/test_domain.rs b/egg-pre-dcp/tests/test_domain.rs index f51b762d..1362471d 100644 --- a/egg-pre-dcp/tests/test_domain.rs +++ b/egg-pre-dcp/tests/test_domain.rs @@ -2,6 +2,8 @@ Tests for extended interval arithmetic. !*/ +mod test_domain { + use egg_pre_dcp; use egg_pre_dcp::domain; @@ -694,3 +696,5 @@ fn is_nonneg_1() { // [1, 1] is non-negative assert!(domain::is_nonneg(&Domain::make_singleton(1.0))); } + +} diff --git a/egg-pre-dcp/tests/test_dqcp.rs b/egg-pre-dcp/tests/test_dqcp.rs index ac74543a..a689fdde 100644 --- a/egg-pre-dcp/tests/test_dqcp.rs +++ b/egg-pre-dcp/tests/test_dqcp.rs @@ -2,6 +2,8 @@ Tests from quasiconvex programming. !*/ +mod test_dqcp { + use egg_pre_dcp::domain; use domain::Domain as Domain; @@ -10,7 +12,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_qcp1() { - convexify_check_with_domain( + pre_dcp_check_with_domain_and_print( vec![("x", domain::pos_dom())], "(var x)", vec![ @@ -21,7 +23,7 @@ fn test_qcp1() { #[test] fn test_qcp2() { let d = Domain::make_oc(domain::zero(), domain::one()); - convexify_check_with_domain( + pre_dcp_check_with_domain_and_print( vec![("x", d)], "(sqrt (sub (div 1 (pow (var x) 2)) 1))", vec![ @@ -29,15 +31,49 @@ fn test_qcp2() { ]) } +#[test] +fn test_qcp3() { + let dx = Domain::make_singleton(12.0); + let dy = Domain::make_cc(domain::make_float(0.001), domain::make_float(6.0)); + pre_dcp_check_with_domain_and_print( + vec![("x", dx), ("y", dy)], + "0", + vec![ + "(le (div (var x) (var y)) 3)" + ]) +} + +#[test] +fn test_qcp4() { + let dx = Domain::make_ci(domain::make_float(10.0)); + pre_dcp_check_with_domain_and_print( + vec![("x", dx)], + "(neg (var x))", + vec![ + "(le (div (sqrt (add (pow (add (var x) 1) 2) 4)) (sqrt (add (pow (var x) 2) 100))) 1)" + ]) +} + +} + +mod test_dqcp_other { + +use egg_pre_dcp::domain; +use domain::Domain as Domain; + +use egg_pre_dcp::test_util::{*}; + #[test] fn test_qcp2_with_params() { let dx = Domain::make_oc(domain::zero(), domain::one()); let da = Domain::make_ci(domain::zero()); let db = Domain::make_io(domain::one()); - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", dx), ("a", da), ("b", db)], "(sqrt (sub (div 1 (pow (var x) 2)) 1))", vec![ "(le (sub (mul (param a) (div 1 (var x))) (mul (sub 1 (param b)) (sqrt (sub 1 (pow (var x) 2))))) 0)" ]) } + +} diff --git a/egg-pre-dcp/tests/test_main_example.rs b/egg-pre-dcp/tests/test_main_example.rs index 3cc66e3c..cda8f81d 100644 --- a/egg-pre-dcp/tests/test_main_example.rs +++ b/egg-pre-dcp/tests/test_main_example.rs @@ -2,16 +2,20 @@ Test motivating example. !*/ +mod test_main_example { + use egg_pre_dcp::domain; use egg_pre_dcp::test_util::{*}; #[test] fn test_main_example() { - convexify_check_with_domain( + pre_dcp_check_with_domain_and_print( vec![("x", domain::pos_dom())], "(var x)", vec![ "(le (div 1 (sqrt (var x))) (exp (var x)))" ]); } + +} diff --git a/egg-pre-dcp/tests/test_misc.rs b/egg-pre-dcp/tests/test_misc.rs index fb27789a..087cf5ac 100644 --- a/egg-pre-dcp/tests/test_misc.rs +++ b/egg-pre-dcp/tests/test_misc.rs @@ -8,27 +8,27 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_norm2_with_one() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::free_dom())], "(sqrt (add (pow (var x) 2) 1))"); } #[test] fn test_sqrt_pow4() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::nonneg_dom())], "(sqrt (pow (var x) 4))"); } #[test] fn test_div_constant_simp() { - convexify_check_expression( + pre_dcp_check_expression( "(div (div (var x) 20) (div 7 20))"); } #[test] fn test_div_constant_le_simp() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::nonneg_dom()), ("y", domain::pos_dom())], "(le (div (qol 1 (var y)) 20) (mul (div 7 20) (sqrt (var x))))"); } diff --git a/egg-pre-dcp/tests/test_quiz.rs b/egg-pre-dcp/tests/test_quiz.rs index a05f6406..33c91fce 100644 --- a/egg-pre-dcp/tests/test_quiz.rs +++ b/egg-pre-dcp/tests/test_quiz.rs @@ -3,75 +3,79 @@ Tests adapted from random (non-DCP) problems generated by the DCP quiz: https://dcp.stanford.edu/quiz !*/ +mod test_quiz { + use egg_pre_dcp::domain; use egg_pre_dcp::test_util::{*}; #[test] -fn test_quiz_1() { - convexify_check_expression_with_domain( +fn test_quiz1() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::pos_dom())], "(inv (inv (var x)))"); } #[test] -fn test_quiz_2() { - convexify_check_expression_with_domain( +fn test_quiz2() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(neg (lse (log (var x)) (log (var y))))"); } #[test] -fn test_quiz_3() { - convexify_check_expression_with_domain( +fn test_quiz3() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::nonneg_dom())], "(pow (sqrt (var x)) 2)"); } #[test] -fn test_quiz_4() { - convexify_check_expression_with_domain( +fn test_quiz4() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::nonneg_dom())], "(neg (abs (sqrt (abs (var x)))))"); } #[test] -fn test_quiz_5() { - convexify_check_expression_with_domain( +fn test_quiz5() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::free_dom())], "(div 1 (exp (var x)))"); } #[test] -fn test_quiz_6() { - convexify_check_expression_with_domain( +fn test_quiz6() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::nonneg_dom())], "(neg (log (pow (mul 364 (var x)) 2)))"); } #[test] -fn test_quiz_7() { - convexify_check_expression_with_domain( +fn test_quiz7() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::pos_dom())], "(pow (geo (add (var x) 2) (div 1 (var x))) 2)"); } #[test] -fn test_quiz_8() { - convexify_check_expression_with_domain( +fn test_quiz8() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::nonneg_dom())], "(neg (log (abs (var x))))"); } #[test] -fn test_quiz_9() { - convexify_check_expression_with_domain( +fn test_quiz9() { + pre_dcp_check_expression_with_domain_and_print( vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(div 1 (qol (inv (var x)) (inv (var y)))))"); } #[test] -fn test_quiz_10() { - convexify_check_expression( +fn test_quiz10() { + pre_dcp_check_expression_and_print( "(pow (log (exp (var x))) 2)"); } + +} diff --git a/egg-pre-dcp/tests/test_rules.rs b/egg-pre-dcp/tests/test_rules.rs index 436317d5..92d2c4f6 100644 --- a/egg-pre-dcp/tests/test_rules.rs +++ b/egg-pre-dcp/tests/test_rules.rs @@ -2,6 +2,8 @@ Tests for specific rewrite rules. !*/ +mod test_rules { + use egg_pre_dcp::domain; use egg_pre_dcp::domain::Domain as Domain; @@ -9,7 +11,7 @@ use egg_pre_dcp::test_util::{*}; #[test] fn test_log_le_log() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "0", vec![ @@ -19,7 +21,7 @@ fn test_log_le_log() { #[test] fn test_sub_iff_add_le() { - convexify_check( + pre_dcp_check( "0", vec![ "(le (add 1 (var x)) (var x))", @@ -28,7 +30,7 @@ fn test_sub_iff_add_le() { #[test] fn test_log_le_log_rev() { - convexify_check( + pre_dcp_check( "0", vec![ "(le (exp (var x)) (exp (var y)))" @@ -37,7 +39,7 @@ fn test_log_le_log_rev() { #[test] fn test_exp_add() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom())], "0", vec![ @@ -47,7 +49,7 @@ fn test_exp_add() { #[test] fn test_exp_neg_eq_one_div_obj() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", Domain::make_ci(domain::one()))], "(mul (var x) (exp (neg (log (var x)))))", vec![ @@ -56,7 +58,7 @@ fn test_exp_neg_eq_one_div_obj() { #[test] fn test_exp_neg_eq_one_div_constr() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", Domain::make_ci(domain::one()))], "(le (mul (var x) (exp (neg (log (var x))))) (var x))", vec![ @@ -65,7 +67,7 @@ fn test_exp_neg_eq_one_div_constr() { #[test] fn test_log_mul_rev_constr() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom())], "0", vec![ @@ -75,7 +77,7 @@ fn test_log_mul_rev_constr() { #[test] fn test_exp_neg_eq_one_div_rev() { - convexify_check( + pre_dcp_check( "(div 1 (exp (var x)))", vec![ "(le 1 (var x))" @@ -84,7 +86,7 @@ fn test_exp_neg_eq_one_div_rev() { #[test] fn test_div_self() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom())], "0", vec![ @@ -94,7 +96,7 @@ fn test_div_self() { #[test] fn test_div_le_iff_rev() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom())], "0", vec![ @@ -104,7 +106,7 @@ fn test_div_le_iff_rev() { #[test] fn test_log_div_rev_obj() { - convexify_check_with_domain( + pre_dcp_check_with_domain( vec![("x", domain::pos_dom())], "(neg (sub (log (pow (var x) 2)) (log (var x))))", vec![ @@ -113,21 +115,23 @@ fn test_log_div_rev_obj() { #[test] fn test_geo_mean_fold() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], "(neg (sqrt (mul (var x) (var y))))"); } #[test] fn test_quad_over_lin_fold() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::free_dom()), ("y", domain::pos_dom())], "(div (pow (var x) 2) (var y))"); } #[test] fn test_norm2_fold() { - convexify_check_expression_with_domain( + pre_dcp_check_expression_with_domain( vec![("x", domain::free_dom()), ("y", domain::free_dom())], "(sqrt (add (pow (var x) 2) (pow (var y) 2)))"); } + +} diff --git a/egg-pre-dcp/tests/test_stanford.rs b/egg-pre-dcp/tests/test_stanford.rs index 9c778bdd..c1e39f12 100644 --- a/egg-pre-dcp/tests/test_stanford.rs +++ b/egg-pre-dcp/tests/test_stanford.rs @@ -3,82 +3,105 @@ Tests from the additional exercises in the Stanford Convex Optimization course: https://github.com/cvxgrp/cvxbook_additional_exercises !*/ -use std::time::Instant; +mod test_stanford { use egg_pre_dcp::domain; use egg_pre_dcp::test_util::{*}; -// TODO: Failing because of qol curvature check simplification. +// Exercise 3.36 (c). #[test] -fn test_3_32() { - // 1 / (x * y) = (x^-0.5)^2 / y - // = qol(x^-0.5, y) (this is just one possibility) - // convexify_check_expression_with_domain( - // vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], - // "(div 1 (mul (var x) (var y)))"); +fn test_stan1() { + // log(e^(2x + 3) + e^(4y + 5)) = lse(2x + 3, 4y + 5) + pre_dcp_check_expression_with_domain_and_print( + vec![("x", domain::free_dom()), ("y", domain::free_dom())], + "(log (add (exp (add (mul 2 (var x)) 3)) (exp (add (mul 4 (var y)) 5))))"); } -// TODO: Failing because of norm2 curvature check simplification. +// Exercise 3.38 (e). #[test] -fn test_3_33() { - // sqrt(1 + x^4 / y) = sqrt(1^2 + (x^2 / y)^2) - // ... = norm2(1, qol(x, y)) - // NOTE(RFM): Constant unfolding issue: x cannot be in the free domain as - // otherwise we cannot convert x^4 to (x^2)^2. - // convexify_check_expression_with_domain( - // vec![("x", domain::nonneg_dom()), ("y", domain::pos_dom())], - // "(sqrt (add 1 (div (pow (var x) 4) (var y))))" - // ); +fn test_stan2() { + // (sqrt(x) + sqrt(y))^2 = x + y + 2sqrt(xy) + // ... = x + y + 2geo(x, y) + pre_dcp_check_expression_with_domain_and_print( + vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], + "(neg (pow (add (sqrt (var x)) (sqrt (var y))) 2))"); } -// TODO: Failing because of norm2 curvature check simplification. +// Exercise 3.67 with n = 3. #[test] -fn test_3_36_a() { - // sqrt(1 + 4x^2 + 16y^2) = sqrt((2x)^2 + (sqrt(1^2 + (4y)^2))^2) - // ... = norm2(2x, norm2(1, 4y)) - // convexify_check_expression_with_domain( - // vec![("x", domain::nonneg_dom()), ("y", domain::nonneg_dom())], - // "(sqrt (add 1 (add (mul 4 (pow (var x) 2)) (mul 16 (pow (var y) 2)))))"); +fn test_stan3() { + pre_dcp_check_expression_with_domain_and_print( + vec![("x", domain::pos_dom()), ("y", domain::pos_dom()), ("z", domain::pos_dom())], + "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (sqrt (var z)))) 2))"); } +// Exercise 3.67 with n = 4. #[test] -fn test_3_36_c() { - // log(e^(2x + 3) + e^(4y + 5)) = lse(2x + 3, 4y + 5) - convexify_check_expression_with_domain( - vec![("x", domain::free_dom()), ("y", domain::free_dom())], - "(log (add (exp (add (mul 2 (var x)) 3)) (exp (add (mul 4 (var y)) 5))))"); +fn test_stan4() { + pre_dcp_check_expression_with_domain_and_print( + vec![("x", domain::pos_dom()), ("y", domain::pos_dom()), ("z", domain::pos_dom()), ("w", domain::pos_dom())], + "(neg (pow (add (sqrt (var x)) (add (sqrt (var y)) (add (sqrt (var z)) (sqrt (var w))))) 2))"); } -#[test] -fn test_3_38_e() { - // (sqrt(x) + sqrt(y))^2 = x + y + 2sqrt(xy) - // ... = x + y + 2geo(x, y) - convexify_check_expression_with_domain( - vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], - "(neg (pow (add (sqrt (var x)) (sqrt (var y))) 2))"); +// TODO: Failing because of qol curvature check simplification. +// #[test] +// fn test_3_32() { +// // 1 / (x * y) = (x^-0.5)^2 / y +// // = qol(x^-0.5, y) (this is just one possibility) +// pre_dcp_check_expression_with_domain( +// vec![("x", domain::pos_dom()), ("y", domain::pos_dom())], +// "(div 1 (mul (var x) (var y)))"); +// } + +// TODO: Failing because of norm2 curvature check simplification. +// #[test] +// fn test_3_33() { +// // sqrt(1 + x^4 / y) = sqrt(1^2 + (x^2 / y)^2) +// // ... = norm2(1, qol(x, y)) +// pre_dcp_check_expression_with_domain( +// vec![("x", domain::nonneg_dom()), ("y", domain::pos_dom())], +// "(sqrt (add 1 (div (pow (var x) 4) (var y))))" +// ); +// } + +// TODO: Failing because of norm2 curvature check simplification. +// #[test] +// fn test_3_36_a() { +// // sqrt(1 + 4x^2 + 16y^2) = sqrt((2x)^2 + (sqrt(1^2 + (4y)^2))^2) +// // ... = norm2(2x, norm2(1, 4y)) +// pre_dcp_check_expression_with_domain( +// vec![("x", domain::nonneg_dom()), ("y", domain::nonneg_dom())], +// "(sqrt (add 1 (add (mul 4 (pow (var x) 2)) (mul 16 (pow (var y) 2)))))"); +// } + } -fn test_3_67_aux(n: usize, node_limit: usize) { +mod test_stanford_hard { + +use egg_pre_dcp::domain; +use egg_pre_dcp::test_util::{*}; + +fn test_stan_hard_aux(n: usize, node_limit: usize) { // Generalizaiton of 3.28. Works for n = 3,4,5,6,7 - // (sqrt(x_1) + ... + sqrt(x_n))^2 + // (sqrt(x_1) + ... + sqrt(x_n))^2 // ... = sum_{i <= n} x_i + 2 * sum_{i < j <= n} geo(x_i, x_j) let build_domain = |n| { - if n < 2 { - panic!("n must be >= 2"); + if n < 2 { + panic!("n must be >= 2"); } let mut v = Vec::new(); - for i in 0..n { - v.push((format!("x{}", i), domain::pos_dom())); + for i in 0..n { + v.push((format!("x{}", i), domain::pos_dom())); } v }; let build_term = |n| { - if n < 2 { - panic!("n must be >= 2"); + if n < 2 { + panic!("n must be >= 2"); } let mut sqrts = Vec::new(); - for i in 0..n { - sqrts.push(format!("(sqrt (var x{}))", i)); + for i in 0..n { + sqrts.push(format!("(sqrt (var x{}))", i)); } let last = sqrts[n-1].clone(); let before_last = sqrts[n-2].clone(); @@ -90,32 +113,21 @@ fn test_3_67_aux(n: usize, node_limit: usize) { format!("(neg (pow {} 2))", t) }; let domain_pre = build_domain(n).clone(); - let domain = + let domain = domain_pre .iter() .map(|(s,d)| (s.as_str(), d.clone())) .collect::>(); - convexify_check_expression_with_domain_and_node_limit(domain, &build_term(n), node_limit); + pre_dcp_check_expression_with_domain_and_node_limit(domain, &build_term(n), node_limit); } +// Exercise 3.67. #[test] -fn test_3_67() { - // for n in 2..11 { - // let now = Instant::now(); - // { - // test_3_67_aux(n, 20000 * n); - // } - // let elapsed = now.elapsed(); - // println!("Time for 3.67 (n={}): {:.2?}", n, elapsed); - // } - // node_limit = 20000 * n - // n=2: 32359 nodes, 8 steps, 7.82s - // n=3: 52870 nodes, 28 steps, 23.97s - // n=4: 63224 nodes, 37 steps, 31.05s - // n=5: 102470 nodes, 87 steps, 46.83s - // n=6: 106484 nodes, 132 steps, 129.44s - // n=7: 135393 nodes, 282 steps, 152.53s - // n=8: 157549 nodes, 566 steps, 269.90s - // n=9: 174177 nodes, 511 steps, 307.08s - // n=10: 200372 nodes, 716 steps, 173.13s -} \ No newline at end of file +fn test_stan_hard() { + let max_n = 4; + for n in 2..max_n { + test_stan_hard_aux(n, 20000 * n); + } +} + +} diff --git a/scripts/evaluation/egg-pre-dcp-options/README.md b/scripts/evaluation/egg-pre-dcp-options/README.md new file mode 100644 index 00000000..816379ad --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/README.md @@ -0,0 +1,14 @@ +Here, we compare `egg-pre-dcp` with different options (iterative node limits and stop-on-success). + +Run this from the root of the repository. + +To extract raw data, build `egg-pre-dcp`. Then choose `` in `scripts/evaluation/data` to store +the results. I usually give it a date, e.g. `2024_02_17_egg-pre-dcp-options`. Then run the +following. +``` +./scripts/evaluation/egg-pre-dcp-options/run_all.sh +``` +To read all the numerical values, and average them out, and plot the results, run the following. +``` +python3 scripts/evaluation/egg-pre-dcp-options/extract_stats.py +``` diff --git a/scripts/evaluation/egg-pre-dcp-options/benchmark.txt b/scripts/evaluation/egg-pre-dcp-options/benchmark.txt new file mode 100644 index 00000000..a47ff361 --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/benchmark.txt @@ -0,0 +1,32 @@ +gp1 +gp2 +gp3 +gp4 +gp5 +gp6 +gp7 +gp8 +gp9 +agp1 +agp2 +agp3 +agp4 +qcp1 +qcp2 +qcp3 +qcp4 +stan1 +stan2 +stan3 +stan4 +quiz1 +quiz2 +quiz3 +quiz4 +quiz5 +quiz6 +quiz7 +quiz8 +quiz9 +quiz10 +main_example diff --git a/scripts/evaluation/egg-pre-dcp-options/extract_stats.py b/scripts/evaluation/egg-pre-dcp-options/extract_stats.py new file mode 100644 index 00000000..78280776 --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/extract_stats.py @@ -0,0 +1,133 @@ +import sys +import re + +def extract_test_results(file_path): + test_results = {} + + with open(file_path, 'r') as file: + data = file.read() + + # Define regex patterns. + test_name_pattern = r'test_(\w+)::(\w+)' + total_time_pattern = r'Total time: (\d+) ms\.' + e_graph_time_pattern = r'E-graph building time: (\d+) ms\.' + step_extraction_time_pattern = r'Step extraction time: (\d+) ms\.' + num_of_nodes_pattern = r'Succeeded with node limit (\d+) \(using (\d+) nodes\).' + num_steps_pattern = r'Number of steps: (\d+)\.' + best_term_size_pattern = r'Best term size: (\d+)\.' + best_num_variables_pattern = r'Best number of variables: (\d+)\.' + + failure_pattern = r'failures:(.*?)test_(\w+)' + + # Extract test names and corresponding total times or failure status. + test_names = re.findall(test_name_pattern, data) + test_names = [x[1] for x in test_names] + total_times = re.findall(total_time_pattern, data) + e_graph_times = re.findall(e_graph_time_pattern, data) + step_extraction_times = re.findall(step_extraction_time_pattern, data) + num_of_nodes = re.findall(num_of_nodes_pattern, data) + num_of_nodes = [int(x[1]) for x in num_of_nodes] + num_steps = re.findall(num_steps_pattern, data) + best_term_sizes = re.findall(best_term_size_pattern, data) + best_num_variables = re.findall(best_num_variables_pattern, data) + + failures = re.findall(failure_pattern, data, re.DOTALL) + failed_tests = ["test_" + test[1] for test in failures] + + print(num_of_nodes) + print(len(test_names), len(total_times)) + assert len(test_names) == len(total_times) + len(failed_tests) + assert len(total_times) == len(e_graph_times) + assert len(total_times) == len(step_extraction_times) + assert len(total_times) == len(num_of_nodes) + assert len(total_times) == len(num_steps) + assert len(total_times) == len(best_term_sizes) + assert len(total_times) == len(best_num_variables) + + # Match test names with their total times or failure status. + i = 0 + for test_name in test_names: + if test_name in failed_tests: + test_results[test_name] = { + 'total_time': None, + 'e_graph_time': None, + 'step_extraction_time': None, + 'num_of_nodes': None, + 'steps': None, + 'term_size': None, + 'num_variables': None + } + else: + test_results[test_name] = { + 'total_time': int(total_times[i]), + 'e_graph_time': int(e_graph_times[i]), + 'step_extraction_time': int(step_extraction_times[i]), + 'num_of_nodes': num_of_nodes[i], + 'steps': int(num_steps[i]), + 'term_size': int(best_term_sizes[i]), + 'num_variables': int(best_num_variables[i]) + } + i += 1 + + return test_results + +# Change this to the path where the result of `run_all.sh` is. +data_path = 'scripts/evaluation/data/' + sys.argv[1] + '/' + +files = { + "stop" : "stop_on_success.txt", + "iter" : "iterative.txt", +} + +results = {} + +for key, value in files.items(): + file_path = data_path + value + results[key] = extract_test_results(file_path) + +benchmark_path = 'scripts/evaluation/egg-pre-dcp-options/benchmark.txt' +benchmark = [] +with open(benchmark_path, 'r') as file: + benchmark = file.readlines() + +print("---- Table with key results on stop and iter ----") +with open(data_path + 'summary.csv', 'w') as file: + for prob in benchmark: + prob = "test_" + prob.strip() + prob_results_stop = results["stop"][prob] + prob_results_iter = results["iter"][prob] + if prob_results_iter["steps"] < 1 or prob_results_stop["steps"] < 1: + continue + # Problem & Time (stop) & Steps (stop) & Term size (stop) & Time (iter) & Steps (iter) & Term size (iter) + values = [prob[5:]] + values += ["-" if prob_results_stop["total_time"] is None else str(prob_results_stop["total_time"]) + " ms"] + values += ["-" if prob_results_stop["steps"] is None else str(prob_results_stop["steps"])] + values += ["-" if prob_results_stop["term_size"] is None else str(prob_results_stop["term_size"])] + values += ["-" if prob_results_iter["total_time"] is None else str(prob_results_iter["total_time"]) + " ms"] + values += ["-" if prob_results_iter["steps"] is None else str(prob_results_iter["steps"])] + values += ["-" if prob_results_iter["term_size"] is None else str(prob_results_iter["term_size"])] + file.write(','.join(values) + '\n') + # To copy-paste into LaTeX. + print(' & '.join(values) + ' \\\\') + print('\\hline') + +for key, _ in files.items(): + print("----- " + key + " -----") + with open(data_path + key + '_summary.csv', 'w') as file: + for prob in benchmark: + prob = "test_" + prob.strip() + prob_results = results[key][prob] + if prob_results["steps"] < 1: + continue + # Problem & Time & Build time & Extraction time & Number of nodes & Steps & Term size + values = [prob[5:]] + values += ["-" if prob_results["total_time"] is None else str(prob_results["total_time"]) + " ms"] + values += ["-" if prob_results["e_graph_time"] is None else str(prob_results["e_graph_time"]) + " ms"] + values += ["-" if prob_results["step_extraction_time"] is None else str(prob_results["step_extraction_time"]) + " ms"] + values += ["-" if prob_results["num_of_nodes"] is None else str(prob_results["num_of_nodes"])] + values += ["-" if prob_results["steps"] is None else str(prob_results["steps"])] + values += ["-" if prob_results["term_size"] is None else str(prob_results["term_size"])] + file.write(','.join(values) + '\n') + # To copy-paste into LaTeX. + print(' & '.join(values) + ' \\\\') + print('\\hline') diff --git a/scripts/evaluation/egg-pre-dcp-options/inspect_outliers.sh b/scripts/evaluation/egg-pre-dcp-options/inspect_outliers.sh new file mode 100755 index 00000000..fa650abe --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/inspect_outliers.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# There are 6 examples with shorter explanations on interative mode. We expected stop-on-success to +# yield shorter explanations. We isolate them here to understand what is happening. + +cd egg-pre-dcp + +# This is the smallest. + +cargo test test_stanford::test_stan2 -- --exact --test-threads=1 --nocapture +RUSTFLAGS="--cfg stop_on_success" cargo test test_stanford::test_stan2 -- --exact --test-threads=1 --nocapture + +# This is the largest. + +# cargo test test_dgp::test_gp6 -- --exact --test-threads=1 --nocapture +# RUSTFLAGS="--cfg stop_on_success" cargo test test_dgp::test_gp6 -- --exact --test-threads=1 --nocapture diff --git a/scripts/evaluation/egg-pre-dcp-options/run_all.sh b/scripts/evaluation/egg-pre-dcp-options/run_all.sh new file mode 100755 index 00000000..1c34c344 --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/run_all.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +OUT=$1 + +mkdir -p ./evaluation/data/$OUT + +./evaluation/egg-pre-dcp-options/run_stop_on_success.sh ./evaluation/data/$OUT/stop_on_success.txt; +./evaluation/egg-pre-dcp-options/run_iterative.sh ./evaluation/data/$OUT/iterative.txt; diff --git a/scripts/evaluation/egg-pre-dcp-options/run_iterative.sh b/scripts/evaluation/egg-pre-dcp-options/run_iterative.sh new file mode 100755 index 00000000..d264a8ac --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/run_iterative.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +SCRIPT=$(readlink -f $0) +DIR=`dirname $SCRIPT` +OUT=$1 + +cd egg-pre-dcp + +cargo test test_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +cargo test test_almost_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +cargo test test_dqcp:: -- --test-threads=1 --nocapture >> ../$OUT +cargo test test_stanford:: -- --test-threads=1 --nocapture >> ../$OUT +cargo test test_quiz:: -- --test-threads=1 --nocapture >> ../$OUT +cargo test test_main_example:: -- --test-threads=1 --nocapture >> ../$OUT diff --git a/scripts/evaluation/egg-pre-dcp-options/run_stop_on_success.sh b/scripts/evaluation/egg-pre-dcp-options/run_stop_on_success.sh new file mode 100755 index 00000000..9474ce7b --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/run_stop_on_success.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +SCRIPT=$(readlink -f $0) +DIR=`dirname $SCRIPT` +OUT=$1 + +cd egg-pre-dcp + +RUSTFLAGS="--cfg stop_on_success" cargo test test_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +RUSTFLAGS="--cfg stop_on_success" cargo test test_almost_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +RUSTFLAGS="--cfg stop_on_success" cargo test test_dqcp:: -- --test-threads=1 --nocapture >> ../$OUT +RUSTFLAGS="--cfg stop_on_success" cargo test test_stanford:: -- --test-threads=1 --nocapture >> ../$OUT +RUSTFLAGS="--cfg stop_on_success" cargo test test_quiz:: -- --test-threads=1 --nocapture >> ../$OUT +RUSTFLAGS="--cfg stop_on_success" cargo test test_main_example:: -- --test-threads=1 --nocapture >> ../$OUT diff --git a/scripts/evaluation/egg-pre-dcp-options/run_with_node_limit.sh b/scripts/evaluation/egg-pre-dcp-options/run_with_node_limit.sh new file mode 100755 index 00000000..6973cbb5 --- /dev/null +++ b/scripts/evaluation/egg-pre-dcp-options/run_with_node_limit.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +SCRIPT=$(readlink -f $0) +DIR=`dirname $SCRIPT` +LIM=$1 +OUT=$2_$LIM.txt + +cd egg-pre-dcp + +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_almost_dgp:: -- --test-threads=1 --nocapture >> ../$OUT +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_dqcp:: -- --test-threads=1 --nocapture >> ../$OUT +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_stanford:: -- --test-threads=1 --nocapture >> ../$OUT +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_quiz:: -- --test-threads=1 --nocapture >> ../$OUT +EGG_PRE_DCP_NODE_LIMIT="$LIM" cargo test test_main_example:: -- --test-threads=1 --nocapture >> ../$OUT diff --git a/scripts/evaluation/lean-pre-dcp/README.md b/scripts/evaluation/lean-pre-dcp/README.md new file mode 100644 index 00000000..01bbc1c5 --- /dev/null +++ b/scripts/evaluation/lean-pre-dcp/README.md @@ -0,0 +1,14 @@ +Here, we compare `egg-pre-dcp` times and `pre_dcp` (Lean) times to calculate the proof replay +overhead. + +Run this from the root of the repository. + +To extract raw data, build `CvxLean`. Then choose `` in `scripts/evaluation/data` to store the +results. I usually give it a date, e.g. `2024_02_17_lean-pre-dcp`. Then run the following. +``` +./scripts/evaluation/lean-pre-dcp/run_all.sh +``` +To read all the numerical values, and average them out, and plot the results, run the following. +``` +python3 scripts/evaluation/lean-pre-dcp/extract_stats.py +``` diff --git a/scripts/evaluation/lean-pre-dcp/extract_stats.py b/scripts/evaluation/lean-pre-dcp/extract_stats.py new file mode 100644 index 00000000..8d29396e --- /dev/null +++ b/scripts/evaluation/lean-pre-dcp/extract_stats.py @@ -0,0 +1,199 @@ +import sys +import re +import matplotlib.pyplot as plt +import numpy as np +import heapq + +def extract_stats_from_file(file_path): + with open(file_path, 'r') as file: + egg_times = [] + command_times = [] + term_sizes = [] + numbers_of_steps = [] + term_jsons = [] + for line in file: + match = re.search(r'Egg time: (\d+) ms', line) + if match: + egg_time = int(match.group(1)) + egg_times.append(egg_time) + match = re.search(r'Command time: (\d+) ms', line) + if match: + command_time = int(match.group(1)) + command_times.append(command_time) + match = re.search(r'Term size: (\d+)', line) + if match: + term_size = int(match.group(1)) + term_sizes.append(term_size) + match = re.search(r'Number of steps: (\d+)', line) + if match: + number_of_steps = int(match.group(1)) + numbers_of_steps.append(number_of_steps) + match = re.search(r'Term JSON: (.+)', line) + if match: + term_json = match.group(1) + term_jsons.append(term_json) + assert len(egg_times) == len(command_times) == len(term_sizes) == len(numbers_of_steps) == len(term_jsons) + return egg_times, command_times, term_sizes, numbers_of_steps, term_jsons + +def plot_stats(fn, xs, xs_label, egg_times, command_times): + plt.scatter(xs, egg_times, color='orange', label='egg', marker='x') + plt.scatter(xs, command_times, color='blue', label='command', marker='x') + plt.xlabel(xs_label) + plt.ylabel('time (ms)') + plt.locator_params(axis="both", integer=True, tight=True) + plt.legend() + plt.savefig(fn) + plt.show() + +def plot_stats1(fn, xs, xs_label, ys, ys_label): + plt.scatter(xs, ys, color='black', marker='x') + plt.xlabel(xs_label) + plt.ylabel(ys_label) + plt.locator_params(axis="both", integer=True, tight=True) + plt.legend() + plt.savefig(fn) + plt.show() + +eval_dir = './scripts/evaluation/data/' + sys.argv[1] +print(eval_dir) +file_paths = [eval_dir + '/lean_pre_dcp_test_out_{}.txt'.format(i) for i in range(1, 6)] + +egg_times = None +command_times = None +term_sizes = None +numbers_of_steps = None +term_jsons = None +for file_path in file_paths: + stats = extract_stats_from_file(file_path) + if egg_times is None: + egg_times = np.array(stats[0]) + else: + egg_times += np.array(stats[0]) + if command_times is None: + command_times = np.array(stats[1]) + else: + command_times += np.array(stats[1]) + if term_sizes is None: + term_sizes = np.array(stats[2]) + else: + term_sizes += np.array(stats[2]) + if numbers_of_steps is None: + numbers_of_steps = np.array(stats[3]) + else: + numbers_of_steps += np.array(stats[3]) + if term_jsons is None: + term_jsons = stats[4] + else: + for i in range(len(term_jsons)): + assert term_jsons[i] == stats[4][i] + + print('Data points:', len(egg_times)) + +egg_times = egg_times / len(file_paths) +command_times = command_times / len(file_paths) +term_sizes = term_sizes / len(file_paths) +numbers_of_steps = numbers_of_steps / len(file_paths) + +# Order matters, this is the order in which tests are run: +# "Unit", "MainExample", "DGP", "AlmostDGP", "DQCP", "Quiz", "Stanford" +unit_start_idx = 0 +unit_count = 113 +main_example_start_idx = unit_count +main_example_count = 1 +dgp_start_idx = unit_count + main_example_count +dgp_count = 9 +adgp_start_idx = unit_count + main_example_count + dgp_count +adgp_count = 4 +dqcp_start_idx = unit_count + main_example_count + dgp_count + adgp_count +dqcp_count = 4 +quiz_start_idx = unit_count + main_example_count + dgp_count + adgp_count + dqcp_count +quiz_count = 10 +stan_start_idx = unit_count + main_example_count + dgp_count + adgp_count + dqcp_count + quiz_count +stan_count = 4 + +unit_idxs = [i for i in range(main_example_start_idx,main_example_start_idx+main_example_count)] + [i for i in range(unit_start_idx, unit_start_idx + unit_count)] +quiz_idxs = [i for i in range(quiz_start_idx, quiz_start_idx + quiz_count)] +stan_idxs = [i for i in range(stan_start_idx, stan_start_idx + stan_count)] +dgp_idxs = [i for i in range(dgp_start_idx, dgp_start_idx + dgp_count)] + [i for i in range(adgp_start_idx, adgp_start_idx + adgp_count)] +dqcp_idxs = [i for i in range(dqcp_start_idx, dqcp_start_idx + dqcp_count)] + +unit_idxs_len = len(unit_idxs) +print('Unit count:', unit_idxs_len) +quiz_idxs_len = len(quiz_idxs) +print('Quiz count:', quiz_idxs_len) +stan_idxs_len = len(stan_idxs) +print('Stan count:', stan_idxs_len) +dgp_idxs_len = len(dgp_idxs) +print('DGP count:', dgp_idxs_len) +dqcp_idxs_len = len(dqcp_idxs) +print('QCP count:', dqcp_idxs_len) +total_len = unit_idxs_len + quiz_idxs_len + stan_idxs_len + dgp_idxs_len + dqcp_idxs_len + +problem_names = [f"unit_{i}" for i in range(1, unit_count + 1)] +problem_names += [f"main_example"] +problem_names += [f"dgp_{i}" for i in range(1, dgp_count + 1)] +problem_names += [f"adgp_{i}" for i in range(1, adgp_count + 1)] +problem_names += [f"dqcp_{i}" for i in range(1, dqcp_count + 1)] +problem_names += [f"quiz_{i}" for i in range(1, quiz_count + 1)] +problem_names += [f"stan_{i}" for i in range(1, stan_count + 1)] + +unit_term_sizes = term_sizes[unit_idxs] +quiz_term_sizes = term_sizes[quiz_idxs] +stan_term_sizes = term_sizes[stan_idxs] +dgp_term_sizes = term_sizes[dgp_idxs] +qcp_term_sizes = term_sizes[dqcp_idxs] + +unit_min_term_size = np.min(unit_term_sizes) +unit_max_term_size = np.max(unit_term_sizes) +quiz_min_term_size = np.min(quiz_term_sizes) +quiz_max_term_size = np.max(quiz_term_sizes) +stan_min_term_size = np.min(stan_term_sizes) +stan_max_term_size = np.max(stan_term_sizes) +dgp_min_term_size = np.min(dgp_term_sizes) +dgp_max_term_size = np.max(dgp_term_sizes) +qcp_min_term_size = np.min(qcp_term_sizes) +qcp_max_term_size = np.max(qcp_term_sizes) + +print('Unit size range:', unit_min_term_size, unit_max_term_size) +print('Quiz size range:', quiz_min_term_size, quiz_max_term_size) +print('Stan size range:', stan_min_term_size, stan_max_term_size) +print('DGP size range:', dgp_min_term_size, dgp_max_term_size) +print('QCP size range:', qcp_min_term_size, qcp_max_term_size) + +print('Egg time average: ', np.average(egg_times)) +print('Command time average: ', np.average(command_times)) +print('Steps average', np.average(numbers_of_steps)) +print('Lean time average: ', np.average(command_times - egg_times)) +command_times_per_step = np.array([0 if numbers_of_steps[i] == 0 else command_times[i] / numbers_of_steps[i] for i in range(len(command_times))]) +print('Average command time per step: ', np.average(command_times_per_step)) + +# Remove some outliers. +to_remove = 8 +idxs_to_remove = heapq.nlargest(to_remove, range(len(command_times)), key=command_times.__getitem__) +for i in range(to_remove): + idx = idxs_to_remove[i] + print('Outlier time:', command_times[idx]) + print('Removing outlier: ' + problem_names[idx]) + +egg_times = np.delete(egg_times, idxs_to_remove) +command_times = np.delete(command_times, idxs_to_remove) +term_sizes = np.delete(term_sizes, idxs_to_remove) +numbers_of_steps = np.delete(numbers_of_steps, idxs_to_remove) + +# Remove problems with 0 steps. +idxs_to_remove = [] +for i in range(len(numbers_of_steps)): + if numbers_of_steps[i] == 0: + idxs_to_remove.append(i) + print('Removing problem with 0 steps: ' + problem_names[i]) + +egg_times = np.delete(egg_times, idxs_to_remove) +command_times = np.delete(command_times, idxs_to_remove) +term_sizes = np.delete(term_sizes, idxs_to_remove) +numbers_of_steps = np.delete(numbers_of_steps, idxs_to_remove) + +lean_times = command_times - egg_times + +plot_stats(eval_dir + "/pre_dcp_eval_sizes_time.png", term_sizes, 'term size', egg_times, command_times) +plot_stats(eval_dir + "/pre_dcp_eval_steps_time.png", numbers_of_steps, 'number of steps', egg_times, command_times) +plot_stats1(eval_dir + "/pre_dcp_eval_sizes_steps.png", term_sizes, 'term size', numbers_of_steps, 'number of steps') diff --git a/scripts/evaluation/lean-pre-dcp/run.sh b/scripts/evaluation/lean-pre-dcp/run.sh new file mode 100755 index 00000000..c78fac66 --- /dev/null +++ b/scripts/evaluation/lean-pre-dcp/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +OUT=$1 + +rm -rf .lake/build/lib/CvxLean/Test/PreDCP + +TMP=$(cat CvxLeanTest.lean) + +L=("Unit" "MainExample" "DGP" "AlmostDGP" "DQCP" "Quiz" "Stanford") +for s in ${L[@]}; do + echo "import CvxLean.Test.PreDCP.$s" > CvxLeanTest.lean + lake build CvxLeanTest >> $OUT +done + +echo $TMP > CvxLeanTest.lean diff --git a/scripts/evaluation/lean-pre-dcp/run_all.sh b/scripts/evaluation/lean-pre-dcp/run_all.sh new file mode 100755 index 00000000..a83fee55 --- /dev/null +++ b/scripts/evaluation/lean-pre-dcp/run_all.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +OUT=$1 + +mkdir -p ./evaluation/data/$OUT + +./evaluation/lean-pre-dcp/run.sh ./evaluation/data/$OUT/lean_pre_dcp_test_out_1.txt; +./evaluation/lean-pre-dcp/run.sh ./evaluation/data/$OUT/lean_pre_dcp_test_out_2.txt; +./evaluation/lean-pre-dcp/run.sh ./evaluation/data/$OUT/lean_pre_dcp_test_out_3.txt; +./evaluation/lean-pre-dcp/run.sh ./evaluation/data/$OUT/lean_pre_dcp_test_out_4.txt; +./evaluation/lean-pre-dcp/run.sh ./evaluation/data/$OUT/lean_pre_dcp_test_out_5.txt