diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 194ff989..08de8dd1 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,9 +3,6 @@ name: docs on: push: branches: ["main"] - - pull_request: - branches: ["main"] workflow_dispatch: @@ -34,10 +31,11 @@ jobs: with: repository: verified-optimization/CvxLean path: CvxLean - - name: Build egg-pre-dcp and CvxLean + - name: Build egg-pre-dcp, CvxLean, and CvxLeanTest working-directory: CvxLean run: | ./build.sh + lake build CvxLeanTest - name: Create dummy docs project run: | # Taken from https://github.com/leanprover-community/mathlib4_docs diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b1f480b2..233009f6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,6 +38,9 @@ jobs: - name: Build egg-pre-dcp and CvxLean run: | ./build.sh + - name: Check style + run: | + ./scripts/style/check_style_all.sh - name: Test egg-pre-dcp working-directory: egg-pre-dcp run: | diff --git a/.gitignore b/.gitignore index 4c143eca..e2c14b1a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ /scripts/evaluation/data /plots -/CvxLean/Tactic/Solver/Mosek/Path.lean +/CvxLean/Command/Solve/Mosek/Path.lean diff --git a/CvxLean/Command/Solve/Conic.lean b/CvxLean/Command/Solve/Conic.lean index 3d23ea2d..be75b0b4 100644 --- a/CvxLean/Command/Solve/Conic.lean +++ b/CvxLean/Command/Solve/Conic.lean @@ -88,7 +88,8 @@ unsafe def solutionDataFromProblemData (minExpr : MinimizationExpr) (data : Prob | Except.error err => throwSolveError err /-- -/ -unsafe def exprFromSolutionData (minExpr : MinimizationExpr) (solData : SolutionData) : MetaM Expr := do +unsafe def exprFromSolutionData (minExpr : MinimizationExpr) (solData : SolutionData) : + MetaM Expr := do let vars ← decomposeDomainInstantiating minExpr -- Generate solution of the correct shape. @@ -114,7 +115,7 @@ unsafe def exprFromSolutionData (minExpr : MinimizationExpr) (solData : Solution let exprs := (solPointExprArrayRaw.drop i).take n -- TODO: Code repetition. - let arrayExpr ← Lean.Expr.mkArray (mkConst ``Real) exprs + let arrayExpr ← Expr.mkArray (mkConst ``Real) exprs let arrayList ← mkAppM ``Array.toList #[arrayExpr] let v ← withLocalDeclD `i' (← mkAppM ``Fin #[toExpr n]) fun i' => do let i'' := mkApp2 (mkConst ``Fin.val) (toExpr n) i' @@ -128,12 +129,12 @@ unsafe def exprFromSolutionData (minExpr : MinimizationExpr) (solData : Solution -- Matrix. let mut exprs := #[] for j in [:m] do - let arrayExpr ← Lean.Expr.mkArray (mkConst ``Real) ((solPointExprArrayRaw.drop (i + j * n)).take n) + let arrayExpr ← Expr.mkArray (mkConst ``Real) + ((solPointExprArrayRaw.drop (i + j * n)).take n) let listExpr ← mkAppM ``Array.toList #[arrayExpr] exprs := exprs.push listExpr - let arrayListExpr ← - Lean.Expr.mkArray (← mkAppM ``List #[mkConst ``Real]) exprs + let arrayListExpr ← Expr.mkArray (← mkAppM ``List #[mkConst ``Real]) exprs -- List of list representing the matrix. let listListExpr ← mkAppM ``Array.toList #[arrayListExpr] diff --git a/CvxLean/Command/Solve/Float/Coeffs.lean b/CvxLean/Command/Solve/Float/Coeffs.lean index aba0c2e0..33bfbf8d 100644 --- a/CvxLean/Command/Solve/Float/Coeffs.lean +++ b/CvxLean/Command/Solve/Float/Coeffs.lean @@ -163,15 +163,16 @@ unsafe def unrollVectors (constraints : Expr) : MetaM (Array Expr) := do let idxExpr ← mkFinIdxExpr i n let ei := mkApp e idxExpr res := res.push (← mkAppM ``Real.zeroCone #[ei]) - -- Vector positive orthant cone. - | .app (.app (.app (.const ``Real.Vec.posOrthCone _) (.app (.const ``Fin _) n)) _) e => + -- Vector nonnegative orthant cone. + | .app (.app (.app (.const ``Real.Vec.nonnegOrthCone _) (.app (.const ``Fin _) n)) _) e => let n : Nat ← evalExpr Nat (mkConst ``Nat) n for i in [:n] do let idxExpr ← mkFinIdxExpr i n let ei := mkApp e idxExpr - res := res.push (← mkAppM ``Real.posOrthCone #[ei]) + res := res.push (← mkAppM ``Real.nonnegOrthCone #[ei]) -- Vector exponential cone. - | .app (.app (.app (.app (.app (.const ``Real.Vec.expCone _) (.app (.const ``Fin _) n)) _) a) b) c => + | .app (.app (.app (.app (.app + (.const ``Real.Vec.expCone _) (.app (.const ``Fin _) n)) _) a) b) c => let n : Nat ← evalExpr Nat (mkConst ``Nat) n for i in [:n] do let idxExpr ← mkFinIdxExpr i n @@ -277,7 +278,7 @@ unsafe def determineCoeffsFromExpr (minExpr : MinimizationExpr) : let res ← determineScalarCoeffsAux e p floatDomain data := data.addZeroConstraint res.1 res.2 idx := idx + 1 - | .app (.const ``Real.posOrthCone _) e => do + | .app (.const ``Real.nonnegOrthCone _) e => do let e ← realToFloat e let res ← determineScalarCoeffsAux e p floatDomain data := data.addPosOrthConstraint res.1 res.2 @@ -312,7 +313,7 @@ unsafe def determineCoeffsFromExpr (minExpr : MinimizationExpr) : let (ea, eb) ← determineScalarCoeffsAux e p floatDomain data := data.addSOConstraint ea eb idx := idx + 1 - | .app (.app (.app (.app (.app (.const ``Real.Matrix.posOrthCone _) + | .app (.app (.app (.app (.app (.const ``Real.Matrix.nonnegOrthCone _) (.app (.const ``Fin _) m)) (.app (.const ``Fin _) n)) _) _) e => do let m : Nat ← evalExpr Nat (mkConst ``Nat) m let n : Nat ← evalExpr Nat (mkConst ``Nat) n diff --git a/CvxLean/Command/Solve/Float/ProblemData.lean b/CvxLean/Command/Solve/Float/ProblemData.lean index d625b316..bbacd516 100644 --- a/CvxLean/Command/Solve/Float/ProblemData.lean +++ b/CvxLean/Command/Solve/Float/ProblemData.lean @@ -11,8 +11,10 @@ The procedure that creates `ProblemData` from an optimization problem can be fou namespace CvxLean -/-- Cones admitting scalar affine constraints. Note that the only cone that admits matrix affine -constraints is the PSD cone. -/ +/-- Cones admitting scalar affine constraints. The `PosOrth` type actually corresponds to the +nonnegative orthant, but at the solver level, it is usually called the positive orthant, so we use +that terminology here. Note that the only cone that admits matrix affine constraints is the PSD +cone. -/ inductive ScalarConeType | Zero | PosOrth | Exp | Q | QR diff --git a/CvxLean/Command/Solve/Float/RealToFloatCmd.lean b/CvxLean/Command/Solve/Float/RealToFloatCmd.lean index c1d6cb6b..1c841326 100644 --- a/CvxLean/Command/Solve/Float/RealToFloatCmd.lean +++ b/CvxLean/Command/Solve/Float/RealToFloatCmd.lean @@ -76,12 +76,12 @@ partial def realToFloat (e : Expr) : MetaM Expr := do environment extension. -/ @[command_elab addRealToFloatCommand] def elabAddRealToFloatCommand : CommandElab -| `(add_real_to_float : $real := $float) => - liftTermElabM do - let real ← elabTermAndSynthesize real.raw none - let float ← elabTermAndSynthesize float.raw none - addRealToFloatData { real := real, float := float } -| _ => throwUnsupportedSyntax + | `(add_real_to_float : $real := $float) => + liftTermElabM do + let real ← elabTermAndSynthesize real.raw none + let float ← elabTermAndSynthesize float.raw none + addRealToFloatData { real := real, float := float } + | _ => throwUnsupportedSyntax macro_rules | `(add_real_to_float $idents:funBinder* : $real := $float) => do diff --git a/CvxLean/Command/Solve/Mosek/CBF.lean b/CvxLean/Command/Solve/Mosek/CBF.lean index 27b9fe84..14f6621f 100644 --- a/CvxLean/Command/Solve/Mosek/CBF.lean +++ b/CvxLean/Command/Solve/Mosek/CBF.lean @@ -567,9 +567,8 @@ def addScalarConstraintsShiftCoord (p : Problem) (i : Nat) (v : Float) : Problem { p with scalarConstraintsShiftCoord := p.scalarConstraintsShiftCoord.stack ev} -def addPSDConstraintsScalarVariablesCoord - (p : Problem) (i : Nat) (eml : EncodedMatrixList) - : Problem := +def addPSDConstraintsScalarVariablesCoord (p : Problem) (i : Nat) (eml : EncodedMatrixList) : + Problem := let emll := EncodedMatrixListList.fromIndexAndEncodedMatrixList i eml { p with PSDConstraintsScalarVariablesCoord := p.PSDConstraintsScalarVariablesCoord.stack emll } diff --git a/CvxLean/Command/Solve/Mosek/Path.lean b/CvxLean/Command/Solve/Mosek/Path.lean index 2a117d02..ec18e75a 100644 --- a/CvxLean/Command/Solve/Mosek/Path.lean +++ b/CvxLean/Command/Solve/Mosek/Path.lean @@ -1,3 +1,4 @@ --- Change this to the location of the Mosek binary, e.g. --- /Users/alice/mosek/10.0/tools/platform/osxaarch64/bin -def mosekBinPath := "/Users/ramonfernandezmir/Documents/PhD-code/optimization/mosek/10.0/tools/platform/osxaarch64/bin" +/-! Change this to the location of the Mosek binary, e.g. +`/Users/alice/mosek/10.0/tools/platform/osxaarch64/bin` -/ + +def mosekBinPath := "" diff --git a/CvxLean/Command/Solve/Mosek/Sol.lean b/CvxLean/Command/Solve/Mosek/Sol.lean index 1b39eec4..f912d851 100644 --- a/CvxLean/Command/Solve/Mosek/Sol.lean +++ b/CvxLean/Command/Solve/Mosek/Sol.lean @@ -133,7 +133,8 @@ instance : ToString Result where "INDEX | NAME | AT | ACTIVITY | LOWER LIMIT | UPPER LIMIT | DUAL LOWER | DUAL UPPER \n" ++ (res.constraints.foldl (fun acc s => acc ++ (toString s) ++ "\n") "") ++ "\n" ++ "VARIABLES \n" ++ - "INDEX | NAME | AT | ACTIVITY | LOWER LIMIT | UPPER LIMIT | DUAL LOWER | DUAL UPPER | CONIC DUAL \n" ++ + "INDEX | NAME | AT | ACTIVITY | LOWER LIMIT | UPPER LIMIT | DUAL LOWER | DUAL UPPER | " ++ + "CONIC DUAL \n" ++ (res.vars.foldl (fun acc s => acc ++ (toString s) ++ "\n") "") ++ "\n" ++ "SYMMETRIC MATRIX VARIABLES \n" ++ "INDEX | NAME | I | J | PRIMAL | DUAL \n" ++ diff --git a/CvxLean/Command/Util/TimeCmd.lean b/CvxLean/Command/Util/TimeCmd.lean index 5759a7ac..ffc7e937 100644 --- a/CvxLean/Command/Util/TimeCmd.lean +++ b/CvxLean/Command/Util/TimeCmd.lean @@ -1,5 +1,10 @@ import Lean +/-! +Place `time_cmd` before any command to print the time taken by the command. Useful for simple +profiling. +-/ + open Lean Elab Command syntax (name := time_cmd) "time_cmd " command : command @@ -7,9 +12,9 @@ syntax (name := time_cmd) "time_cmd " command : command @[command_elab «time_cmd»] def evalTimeCmd : CommandElab := fun stx => match stx with | `(time_cmd $cmd) => do - let before ← BaseIO.toIO IO.monoMsNow - elabCommand cmd - let after ← BaseIO.toIO IO.monoMsNow - let diff := after - before - dbg_trace s!"Command time: {diff} ms" + let before ← BaseIO.toIO IO.monoMsNow + elabCommand cmd + let after ← BaseIO.toIO IO.monoMsNow + let diff := after - before + dbg_trace s!"Command time: {diff} ms" | _ => throwUnsupportedSyntax diff --git a/CvxLean/Demos/LT2024.lean b/CvxLean/Demos/LT2024.lean index 742b94a8..87c512ba 100644 --- a/CvxLean/Demos/LT2024.lean +++ b/CvxLean/Demos/LT2024.lean @@ -1,5 +1,10 @@ import CvxLean +/-! +This demo was used for Lean Together 2024, see the video here: + +-/ + namespace LT2024 noncomputable section @@ -70,9 +75,6 @@ def p₃ (Awall Aflr α β γ δ : ℝ) := c8 : γ ≤ d / w c9 : d / w ≤ δ - -set_option trace.Meta.debug true - equivalence* eqv₃/q₃ : p₃ 100 10 0.5 2 0.5 2 := by change_of_variables! (h') (h ↦ exp h') change_of_variables! (w') (w ↦ exp w') diff --git a/CvxLean/Examples/CovarianceEstimation.lean b/CvxLean/Examples/CovarianceEstimation.lean index 97940e8a..c16e8879 100644 --- a/CvxLean/Examples/CovarianceEstimation.lean +++ b/CvxLean/Examples/CovarianceEstimation.lean @@ -91,7 +91,7 @@ solve covEstimationConvex nₚ Nₚ αₚ yₚ -- -(-(Nₚ • log (sqrt ((2 * π) ^ nₚ)) + Nₚ • (-Vec.sum t.0 / 2)) + -- -(↑Nₚ * trace ((covarianceMatrix fun i => yₚ i) * Rᵀ) / 2)) -- subject to --- _ : Real.posOrthCone (αₚ - sum T.2) +-- _ : Real.nonnegOrthCone (αₚ - sum T.2) -- _ : Vec.expCone t.0 1 (diag Y.1) -- _ : -- PSDCone @@ -99,8 +99,8 @@ solve covEstimationConvex nₚ Nₚ αₚ yₚ -- let D := diagonal (diag Y.1); -- let X := fromBlocks D Z Zᵀ R; -- X) --- _ : Matrix.posOrthCone (T.2 - R) --- _ : Matrix.posOrthCone (T.2 + R) +-- _ : Matrix.nonnegOrthCone (T.2 - R) +-- _ : Matrix.nonnegOrthCone (T.2 + R) #eval covEstimationConvex.status -- "PRIMAL_AND_DUAL_FEASIBLE" #eval covEstimationConvex.value -- 14.124098 diff --git a/CvxLean/Examples/FittingSphere.lean b/CvxLean/Examples/FittingSphere.lean index 9c32c56c..378ced19 100644 --- a/CvxLean/Examples/FittingSphere.lean +++ b/CvxLean/Examples/FittingSphere.lean @@ -23,7 +23,7 @@ def mean {n : ℕ} (a : Fin n → ℝ) : ℝ := (1 / n) * ∑ i, (a i) `leastSquares_optimal_eq_mean`, following Marty Cohen's answer in https://math.stackexchange.com/questions/2554243. -/ lemma leastSquares_alt_objFun {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ) : - (∑ i, ((a i - x) ^ 2)) = n * ((x - mean a) ^ 2 + (mean (a ^ 2) - (mean a) ^ 2)) := by + (∑ i, ((a i - x) ^ 2)) = n * ((x - mean a) ^ 2 + (mean (a ^ 2) - (mean a) ^ 2)) := by calc -- 1) Σ (aᵢ - x)² = Σ (aᵢ² - 2aᵢx + x²) _ = ∑ i, ((a i) ^ 2 - 2 * (a i) * x + (x ^ 2)) := by @@ -40,7 +40,7 @@ lemma leastSquares_alt_objFun {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : /-- Key result about least squares: `x* = mean a`. -/ lemma leastSquares_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ) - (h : (leastSquares a).optimal x) : x = mean a := by + (h : (leastSquares a).optimal x) : x = mean a := by simp [optimal, feasible, leastSquares] at h replace h : ∀ y, (x - mean a) ^ 2 ≤ (y - mean a) ^ 2 := by intros y @@ -60,7 +60,7 @@ def leastSquaresVec {n : ℕ} (a : Fin n → ℝ) := /-- Same as `leastSquares_optimal_eq_mean` in vector notation. -/ lemma leastSquaresVec_optimal_eq_mean {n : ℕ} (hn : 0 < n) (a : Fin n → ℝ) (x : ℝ) - (h : (leastSquaresVec a).optimal x) : x = mean a := by + (h : (leastSquaresVec a).optimal x) : x = mean a := by apply leastSquares_optimal_eq_mean hn a simp [leastSquaresVec, leastSquares, optimal, feasible] at h ⊢ intros y @@ -96,7 +96,7 @@ equivalence* eqv/fittingSphereT (n m : ℕ) (x : Fin m → Fin n → ℝ) : fitt equivalence_step => apply ChangeOfVariables.toEquivalence (fun (ct : (Fin n → ℝ) × ℝ) => (ct.1, sqrt (ct.2 + ‖ct.1‖ ^ 2))) - . rintro _ h; exact le_of_lt h + · rintro _ h; exact le_of_lt h rename_vars [c, t] -- Clean up. conv_constr h₁ => dsimp @@ -128,21 +128,21 @@ lemma vec_squared_norm_error_eq_zero_iff {n m : ℕ} (a : Fin m → Fin n → simp [rpow_two] rw [sum_eq_zero_iff_of_nonneg (fun _ _ => sq_nonneg _)] constructor - . intros h i + · intros h i have hi := h i (by simp) rwa [sq_eq_zero_iff, norm_eq_zero, sub_eq_zero] at hi - . intros h i _ + · intros h i _ rw [sq_eq_zero_iff, norm_eq_zero, sub_eq_zero] exact h i /-- This tells us that solving the relaxed problem is sufficient for optimal points if the solution is non-trivial. -/ lemma optimal_convex_implies_optimal_t (hm : 0 < m) (c : Fin n → ℝ) (t : ℝ) - (h_nontrivial : x ≠ Vec.const m c) - (h_opt : (fittingSphereConvex n m x).optimal (c, t)) : (fittingSphereT n m x).optimal (c, t) := by + (h_nontrivial : x ≠ Vec.const m c) (h_opt : (fittingSphereConvex n m x).optimal (c, t)) : + (fittingSphereT n m x).optimal (c, t) := by simp [fittingSphereT, fittingSphereConvex, optimal, feasible] at h_opt ⊢ constructor - . let a := Vec.norm x ^ 2 - 2 * mulVec x c + · let a := Vec.norm x ^ 2 - 2 * mulVec x c have h_ls : optimal (leastSquaresVec a) t := by refine ⟨trivial, ?_⟩ intros y _ @@ -177,12 +177,12 @@ lemma optimal_convex_implies_optimal_t (hm : 0 < m) (c : Fin n → ℝ) (t : ℝ exfalso rw [h_t_add_c2_eq, zero_eq_mul] at h_t_add_c2_eq_zero rcases h_t_add_c2_eq_zero with (hc | h_sum_eq_zero) - . simp at hc; linarith + · simp at hc; linarith rw [vec_squared_norm_error_eq_zero_iff] at h_sum_eq_zero apply h_nontrivial funext i exact h_sum_eq_zero i - . intros c' x' _ + · intros c' x' _ exact h_opt c' x' /-- We express the nontriviality condition only in terms of `x` so that it can be checked. -/ diff --git a/CvxLean/Examples/TrussDesign.lean b/CvxLean/Examples/TrussDesign.lean index fbbda49f..c573718e 100644 --- a/CvxLean/Examples/TrussDesign.lean +++ b/CvxLean/Examples/TrussDesign.lean @@ -63,7 +63,7 @@ equivalence* eqv₁/trussDesignGP (hmin hmax wmin wmax Rmax σ F₁ F₂ : ℝ) apply ChangeOfVariables.toEquivalence (fun (hwrA : ℝ × ℝ × ℝ × ℝ) => (hwrA.1, hwrA.2.1, hwrA.2.2.1, sqrt (hwrA.2.2.2 / (2 * π) + hwrA.2.2.1 ^ 2))) - . rintro ⟨h, w, r, R⟩ ⟨c_r, _, _, _, _, _, _, c_R_lb, _⟩; dsimp at * + · rintro ⟨h, w, r, R⟩ ⟨c_r, _, _, _, _, _, _, c_R_lb, _⟩; dsimp at * simp [ChangeOfVariables.condition] have h_r_le : r ≤ 1.1 * r := (le_mul_iff_one_le_left c_r).mpr (by norm_num) exact le_trans (le_of_lt c_r) (le_trans h_r_le c_R_lb) @@ -122,12 +122,12 @@ equivalence* eqv₂/trussDesignConvex (hmin hmax : ℝ) (hmin_pos : 0 < hmin) apply ChangeOfVariables.toEquivalence (fun (hwrA : ℝ × ℝ × ℝ × ℝ) => (exp hwrA.1, exp hwrA.2.1, exp hwrA.2.2.1, exp hwrA.2.2.2)) - . rintro ⟨h, w, r, A⟩ ⟨c_r, _, _, c_hmin, _, c_wmin, _, c_A_lb, _⟩; dsimp at * + · rintro ⟨h, w, r, A⟩ ⟨c_r, _, _, c_hmin, _, c_wmin, _, c_A_lb, _⟩; dsimp at * split_ands - . exact lt_of_lt_of_le hmin_pos c_hmin - . exact lt_of_lt_of_le wmin_pos c_wmin - . exact c_r - . have h_A_div_2π_pos : 0 < A / (2 * π) := lt_of_lt_of_le (by positivity) c_A_lb + · exact lt_of_lt_of_le hmin_pos c_hmin + · exact lt_of_lt_of_le wmin_pos c_wmin + · exact c_r + · have h_A_div_2π_pos : 0 < A / (2 * π) := lt_of_lt_of_le (by positivity) c_A_lb rwa [lt_div_iff (by positivity), zero_mul] at h_A_div_2π_pos conv_opt => dsimp rename_vars [h', w', r', A'] diff --git a/CvxLean/Examples/VehicleSpeedScheduling.lean b/CvxLean/Examples/VehicleSpeedScheduling.lean index 463c7271..51d96420 100644 --- a/CvxLean/Examples/VehicleSpeedScheduling.lean +++ b/CvxLean/Examples/VehicleSpeedScheduling.lean @@ -49,8 +49,8 @@ private lemma simp_vec_fraction (h_d_pos : StrongLT 0 d) (s : Fin n → ℝ) (i private lemma fold_partial_sum [hn : Fact (0 < n)] (t : Fin n → ℝ) (i : Fin n) : ∑ j in [[0, i]], t j = Vec.cumsum t i := by simp [Vec.cumsum]; split_ifs - . rfl - . linarith [hn.out] + · rfl + · linarith [hn.out] equivalence* eqv₁/vehSpeedSchedConvex (n : ℕ) (d : Fin n → ℝ) (τmin τmax smin smax : Fin n → ℝ) (F : ℝ → ℝ) (h_n_pos : 0 < n) (h_d_pos : StrongLT 0 d) @@ -62,7 +62,7 @@ equivalence* eqv₁/vehSpeedSchedConvex (n : ℕ) (d : Fin n → ℝ) -- TODO: This can be done by change of variables by detecting that the variable is a vector. equivalence_step => apply ChangeOfVariables.toEquivalence (fun t => d / t) - . rintro s ⟨c_smin, _⟩ i; split_ands <;> linarith [h_smin_pos i, c_smin i, h_d_pos i] + · rintro s ⟨c_smin, _⟩ i; split_ands <;> linarith [h_smin_pos i, c_smin i, h_d_pos i] rename_vars [t] -- Clean up divisions introduced by the change of variables. conv_obj => simp only [Pi.div_apply, simp_vec_fraction d h_d_pos] @@ -111,7 +111,7 @@ equivalence* eqv₂/vehSpeedSchedQuadratic (n : ℕ) (d : Fin n → ℝ) -- Add constraint to tell the system that `t` is positive. equivalence_step => apply Equivalence.add_constraint (cs' := fun t => StrongLT 0 t) - . rintro t ⟨c_smin, _⟩ i + · rintro t ⟨c_smin, _⟩ i exact t_pos_of_c_smin t c_smin i rename_vars [t] rename_constrs [c_t, c_smin, c_smax, c_τmin, c_τmax] @@ -142,26 +142,23 @@ equivalence* eqv₂/vehSpeedSchedQuadratic (n : ℕ) (d : Fin n → ℝ) -- Now, let's solve a concrete instance of the problem: -- https://github.com/cvxgrp/cvxbook_additional_exercises/blob/main/python/veh_speed_sched_data.py -set_option maxRecDepth 1000000 -set_option maxHeartbeats 1000000 - -@[optimization_param] +@[optimization_param, reducible] def nₚ : ℕ := 10 lemma nₚ_pos : 0 < nₚ := by unfold nₚ; norm_num -@[optimization_param] +@[optimization_param, reducible] def dₚ : Fin nₚ → ℝ := ![1.9501, 1.2311, 1.6068, 1.4860, 1.8913, 1.7621, 1.4565, 1.0185, 1.8214, 1.4447] lemma dₚ_pos : StrongLT 0 dₚ := by intro i; fin_cases i <;> (simp [dₚ]; norm_num) -@[optimization_param] +@[optimization_param, reducible] def τminₚ : Fin nₚ → ℝ := ![1.0809, 2.7265, 3.5118, 5.3038, 5.4516, 7.1648, 9.2674, 12.1543, 14.4058, 16.6258] -@[optimization_param] +@[optimization_param, reducible] def τmaxₚ : Fin nₚ → ℝ := ![4.6528, 6.5147, 7.5178, 9.7478, 9.0641, 10.3891, 13.1540, 16.0878, 17.4352, 20.9539] @@ -169,7 +166,7 @@ def τmaxₚ : Fin nₚ → ℝ := def sminₚ : Fin nₚ → ℝ := ![0.7828, 0.6235, 0.7155, 0.5340, 0.6329, 0.4259, 0.7798, 0.9604, 0.7298, 0.8405] -@[optimization_param] +@[optimization_param, reducible] def smaxₚ : Fin nₚ → ℝ := ![1.9624, 1.6036, 1.6439, 1.5641, 1.7194, 1.9090, 1.3193, 1.3366, 1.9470, 2.8803] diff --git a/CvxLean/Lib/Cones/All.lean b/CvxLean/Lib/Cones/All.lean index 642b1ef6..f5946092 100644 --- a/CvxLean/Lib/Cones/All.lean +++ b/CvxLean/Lib/Cones/All.lean @@ -1,5 +1,5 @@ import CvxLean.Lib.Cones.ExpCone import CvxLean.Lib.Cones.SOCone -import CvxLean.Lib.Cones.PosOrthCone +import CvxLean.Lib.Cones.NonnegOrthCone import CvxLean.Lib.Cones.ZeroCone import CvxLean.Lib.Cones.PSDCone diff --git a/CvxLean/Lib/Cones/ExpCone.lean b/CvxLean/Lib/Cones/ExpCone.lean index d14fd6b3..e6d2e245 100644 --- a/CvxLean/Lib/Cones/ExpCone.lean +++ b/CvxLean/Lib/Cones/ExpCone.lean @@ -1,6 +1,8 @@ import Mathlib.Data.Complex.Exponential /-! +Exponential cones. + We follow the MOSEK modeling cookbook: https://docs.mosek.com/modeling-cookbook/expo.html -/ diff --git a/CvxLean/Lib/Cones/NonnegOrthCone.lean b/CvxLean/Lib/Cones/NonnegOrthCone.lean new file mode 100644 index 00000000..20253dda --- /dev/null +++ b/CvxLean/Lib/Cones/NonnegOrthCone.lean @@ -0,0 +1,23 @@ +import CvxLean.Lib.Math.Data.Real +import CvxLean.Lib.Math.Data.Matrix + +/-! +Nonnegative orthant cones. +-/ + +namespace Real + +/-- The nonnegative orthant + `ℝ₊ := { x | 0 ≤ x } ⊆ ℝ`. -/ +def nonnegOrthCone (x : ℝ) : Prop := + 0 ≤ x + +/-- The `n`-dimensional nonnegative orthant `ℝ₊ⁿ`. -/ +def Vec.nonnegOrthCone {n} [Fintype n] (x : n → ℝ) : Prop := + ∀ i, Real.nonnegOrthCone (x i) + +/-- The `n×m`-dimensional nonnegative orthant `ℝ₊ⁿˣᵐ`. -/ +def Matrix.nonnegOrthCone {n m} [Fintype n] [Fintype m] (M : Matrix n m ℝ) : Prop := + ∀ i j, Real.nonnegOrthCone (M i j) + +end Real diff --git a/CvxLean/Lib/Cones/PSDCone.lean b/CvxLean/Lib/Cones/PSDCone.lean index b1668c39..78b2f9f7 100644 --- a/CvxLean/Lib/Cones/PSDCone.lean +++ b/CvxLean/Lib/Cones/PSDCone.lean @@ -1,6 +1,10 @@ import Mathlib.Data.Real.Basic import Mathlib.LinearAlgebra.Matrix.PosDef +/-! +Positive semidefinite cone. +-/ + namespace Real /-- The cone of `n×n` positive semidefinite matrices diff --git a/CvxLean/Lib/Cones/PosOrthCone.lean b/CvxLean/Lib/Cones/PosOrthCone.lean deleted file mode 100644 index 6f336abf..00000000 --- a/CvxLean/Lib/Cones/PosOrthCone.lean +++ /dev/null @@ -1,20 +0,0 @@ -import CvxLean.Lib.Math.Data.Real -import CvxLean.Lib.Math.Data.Matrix - -namespace Real - -/-- The positive (actually, nonnegative) orthant - `ℝ₊ := { x | 0 ≤ x } ⊆ ℝ`. -/ -def posOrthCone (x : ℝ) : Prop := - 0 ≤ x - -/-- The `n`-dimensional positive orthant `ℝ₊ⁿ`. -/ -def Vec.posOrthCone {n} [Fintype n] (x : n → ℝ) : Prop := - ∀ i, Real.posOrthCone (x i) - -/-- The `n×m`-dimensional positive orthant `ℝ₊ⁿˣᵐ`. -/ -def Matrix.posOrthCone {n m} [Fintype n] [Fintype m] (M : Matrix n m ℝ) : - Prop := - ∀ i j, Real.posOrthCone (M i j) - -end Real diff --git a/CvxLean/Lib/Cones/SOCone.lean b/CvxLean/Lib/Cones/SOCone.lean index 592faffd..e23db4e2 100644 --- a/CvxLean/Lib/Cones/SOCone.lean +++ b/CvxLean/Lib/Cones/SOCone.lean @@ -5,6 +5,8 @@ import CvxLean.Lib.Math.Data.Real import CvxLean.Lib.Math.Data.Vec /-! +Second-order cones. + We follow the MOSEK modeling cookbook: https://docs.mosek.com/modeling-cookbook/cqo.html -/ @@ -53,8 +55,8 @@ lemma rotateSoCone_rotatedSoCone {n : ℕ} {t : ℝ} {x : Fin n.succ → ℝ} (h have hrw : (t + x 0) * (t - x 0) = t ^ 2 - x 0 ^ 2 := by norm_cast; ring rw [hrw, le_sub_iff_add_le, add_comm] unfold soCone at h; norm_cast at h ⊢ - rw [←Fin.sum_univ_succ (f := fun i => (x i) ^ 2)] - rw [←sqrt_le_left ht] + rw [← Fin.sum_univ_succ (f := fun i => (x i) ^ 2)] + rw [← sqrt_le_left ht] exact h } { simp [le_div_iff]; linarith } { simp [le_div_iff]; linarith } @@ -72,7 +74,7 @@ lemma unrotateSoCone_soCone {n : ℕ} {v w : ℝ} {x : Fin n → ℝ} (h : rotat { simp [le_div_iff]; linarith } { rw [Fin.sum_univ_succ] simp [Matrix.vecCons] - rw [add_comm, ←le_sub_iff_add_le] + rw [add_comm, ← le_sub_iff_add_le] field_simp have hrw : ((v + w) ^ 2 - (v - w) ^ 2) / 2 = v * w * 2 := by norm_cast; ring norm_cast at hrw h @@ -86,26 +88,25 @@ section Lemmas `x, y ≥ 0` and `z ∈ ℝ`, `((x + y), (x - y, 2z)ᵀ) ∈ 𝒬ⁿ⁺¹ ↔ z ^ 2 ≤ xy`. -/ lemma soCone_add_sub_two_mul_of_nonneg {x y : ℝ} (z : ℝ) - (hx : 0 ≤ x) (hy : 0 ≤ y) : - soCone (x + y) ![x - y, 2 * z] ↔ z ^ (2 : ℝ) ≤ x * y := by + (hx : 0 ≤ x) (hy : 0 ≤ y) : soCone (x + y) ![x - y, 2 * z] ↔ z ^ (2 : ℝ) ≤ x * y := by have hxy := add_nonneg hx hy - conv => lhs; unfold soCone; simp [sqrt_le_left hxy, ←le_sub_iff_add_le'] + conv => lhs; unfold soCone; simp [sqrt_le_left hxy, ← le_sub_iff_add_le'] ring_nf; simp /-- Same as `soCone_add_sub_two_mul_of_nonneg` with `z = 1`. -/ lemma soCone_add_sub_two_of_nonneg {x y : ℝ} (hx : 0 ≤ x) (hy : 0 ≤ y) : - soCone (x + y) ![x - y, 2] ↔ 1 ≤ x * y := by + soCone (x + y) ![x - y, 2] ↔ 1 ≤ x * y := by have h := soCone_add_sub_two_mul_of_nonneg 1 hx hy rw [mul_one, one_rpow] at h exact h /-- Same as `soCone_add_sub_two_mul_of_nonneg` replacing `y` by `-y`. -/ lemma soCone_sub_add_two_mul_of_nonneg {x y : ℝ} (z : ℝ) : - soCone (x - y) ![x + y, 2 * z] ↔ y ≤ x ∧ z ^ (2 : ℝ) ≤ -(x * y) := by - conv => lhs; unfold soCone; simp [sqrt_le_iff, ←le_sub_iff_add_le'] + soCone (x - y) ![x + y, 2 * z] ↔ y ≤ x ∧ z ^ (2 : ℝ) ≤ -(x * y) := by + conv => lhs; unfold soCone; simp [sqrt_le_iff, ← le_sub_iff_add_le'] apply Iff.and - . rfl - . ring_nf!; rw [←neg_mul, ←div_le_iff (by norm_num)]; simp + · rfl + · ring_nf!; rw [← neg_mul, ← div_le_iff (by norm_num)]; simp open Real Matrix diff --git a/CvxLean/Lib/Cones/ZeroCone.lean b/CvxLean/Lib/Cones/ZeroCone.lean index 1c81e566..400b396b 100644 --- a/CvxLean/Lib/Cones/ZeroCone.lean +++ b/CvxLean/Lib/Cones/ZeroCone.lean @@ -1,6 +1,10 @@ import Mathlib.Data.Real.Basic import CvxLean.Lib.Math.Data.Matrix +/-! +Zero cones. +-/ + namespace Real /-- The zero cone `{0}`. -/ diff --git a/CvxLean/Lib/Math/Analysis/InnerProductSpace/GramSchmidtOrtho.lean b/CvxLean/Lib/Math/Analysis/InnerProductSpace/GramSchmidtOrtho.lean index 5a90648c..a2384f45 100644 --- a/CvxLean/Lib/Math/Analysis/InnerProductSpace/GramSchmidtOrtho.lean +++ b/CvxLean/Lib/Math/Analysis/InnerProductSpace/GramSchmidtOrtho.lean @@ -1,5 +1,9 @@ import Mathlib.Analysis.InnerProductSpace.GramSchmidtOrtho +/-! +The Gram-Schmidt algorithm respects basis vectors. +-/ + section GramSchmidt variable (𝕜 : Type _) {E : Type _} [IsROrC 𝕜] diff --git a/CvxLean/Lib/Math/Analysis/InnerProductSpace/Spectrum.lean b/CvxLean/Lib/Math/Analysis/InnerProductSpace/Spectrum.lean index 6d842645..64e368ab 100644 --- a/CvxLean/Lib/Math/Analysis/InnerProductSpace/Spectrum.lean +++ b/CvxLean/Lib/Math/Analysis/InnerProductSpace/Spectrum.lean @@ -1,5 +1,9 @@ import Mathlib.LinearAlgebra.Matrix.Spectrum +/-! +Version of the spectral theorem. +-/ + namespace LinearMap variable {𝕜 : Type _} [IsROrC 𝕜] [DecidableEq 𝕜] @@ -8,10 +12,9 @@ variable [FiniteDimensional 𝕜 E] variable {n : ℕ} (hn : FiniteDimensional.finrank 𝕜 E = n) variable {T : E →ₗ[𝕜] E} -/-- *Diagonalization theorem*, *spectral theorem*; version 3: A self-adjoint -operator `T` on a finite-dimensional inner product space `E` acts diagonally on -the identification of `E` with Euclidean space induced by an orthonormal basis -of eigenvectors of `T`. -/ +/-- *Diagonalization theorem*, *spectral theorem*; version 3: A self-adjoint operator `T` on a +finite-dimensional inner product space `E` acts diagonally on the identification of `E` with +Euclidean space induced by an orthonormal basis of eigenvectors of `T`. -/ lemma spectral_theorem' (v : E) (i : Fin n) (xs : OrthonormalBasis (Fin n) 𝕜 E) (as : Fin n → ℝ) (hxs : ∀ j, Module.End.HasEigenvector T (as j) (xs j)) : diff --git a/CvxLean/Lib/Math/Data/Array.lean b/CvxLean/Lib/Math/Data/Array.lean index b3cdba82..a117e6ef 100644 --- a/CvxLean/Lib/Math/Data/Array.lean +++ b/CvxLean/Lib/Math/Data/Array.lean @@ -1,6 +1,10 @@ import Mathlib.Data.List.Basic import Mathlib.Algebra.Group.Defs +/-! +Some extra operations on arrays. +-/ + open Lean variable {α β γ} diff --git a/CvxLean/Lib/Math/Data/List.lean b/CvxLean/Lib/Math/Data/List.lean index ddb8100d..d74fe1c3 100644 --- a/CvxLean/Lib/Math/Data/List.lean +++ b/CvxLean/Lib/Math/Data/List.lean @@ -1,5 +1,9 @@ import Mathlib.Algebra.Ring.Basic +/-! +Some extra operations on lists. +-/ + namespace List variable {α} diff --git a/CvxLean/Lib/Math/Data/Matrix.lean b/CvxLean/Lib/Math/Data/Matrix.lean index 6db3105a..e026051a 100644 --- a/CvxLean/Lib/Math/Data/Matrix.lean +++ b/CvxLean/Lib/Math/Data/Matrix.lean @@ -28,15 +28,15 @@ def abs (A : Matrix m n ℝ) : Matrix m n ℝ := fun i j => |A i j| theorem vecCons_zero_zero {n} [Zero α] : vecCons (0 : α) (0 : Fin n → α) = 0 := by - ext i ; refine' Fin.cases _ _ i <;> simp [vecCons] + ext i; refine' Fin.cases _ _ i <;> simp [vecCons] theorem smul_vecCons {n} [Zero α] [SMulZeroClass ℝ α] (x : ℝ) (y : α) (v : Fin n → α) : x • vecCons y v = vecCons (x • y) (x • v) := by - ext i ; refine' Fin.cases _ _ i <;> simp [vecCons] + ext i; refine' Fin.cases _ _ i <;> simp [vecCons] theorem add_vecCons {n} [Zero α] [SMulZeroClass ℝ α] [Add α] (x : α) (v : Fin n → α) (y : α) (w : Fin n → α) : vecCons x v + vecCons y w = vecCons (x + y) (v + w) := by - ext i ; refine' Fin.cases _ _ i <;> simp [vecCons] + ext i; refine' Fin.cases _ _ i <;> simp [vecCons] open BigOperators diff --git a/CvxLean/Lib/Math/Data/Real.lean b/CvxLean/Lib/Math/Data/Real.lean index 25963dc1..f14876e2 100644 --- a/CvxLean/Lib/Math/Data/Real.lean +++ b/CvxLean/Lib/Math/Data/Real.lean @@ -84,7 +84,7 @@ lemma log_eq_log {x y : ℝ} (hx : 0 < x) (hy : 0 < y) : x = y ↔ log x = log y lemma div_pow_eq_mul_pow_neg {a b c : ℝ} (hb : 0 ≤ b) : a / (b ^ c) = a * b ^ (-c) := by - rw [div_eq_mul_inv, ←rpow_neg hb] + rw [div_eq_mul_inv, ← rpow_neg hb] lemma one_div_eq_pow_neg_one {a : ℝ} (ha : 0 < a) : 1 / a = a ^ (-1 : ℝ) := by rw [rpow_neg (le_of_lt ha), rpow_one, div_eq_mul_inv, one_mul] diff --git a/CvxLean/Lib/Math/Data/Vec.lean b/CvxLean/Lib/Math/Data/Vec.lean index b03958de..53fd8095 100644 --- a/CvxLean/Lib/Math/Data/Vec.lean +++ b/CvxLean/Lib/Math/Data/Vec.lean @@ -103,16 +103,16 @@ variable {a b c : m → ℝ} lemma div_le_iff (hb : StrongLT 0 b) : a / b ≤ c ↔ a ≤ c * b := by constructor - . intro h i; have hi := h i; simp at hi; + · intro h i; have hi := h i; simp at hi; rw [_root_.div_le_iff (hb i)] at hi; exact hi - . intro h i; have hi := h i; simp at hi; + · intro h i; have hi := h i; simp at hi; dsimp; rw [_root_.div_le_iff (hb i)]; exact hi lemma le_div_iff (hc : StrongLT 0 c) : a ≤ b / c ↔ a * c ≤ b := by constructor - . intro h i; have hi := h i; simp at hi; + · intro h i; have hi := h i; simp at hi; rw [_root_.le_div_iff (hc i)] at hi; exact hi - . intro h i; have hi := h i; simp at hi; + · intro h i; have hi := h i; simp at hi; dsimp; rw [_root_.le_div_iff (hc i)]; exact hi end RealLemmas diff --git a/CvxLean/Lib/Math/LinearAlgebra/Eigenspace.lean b/CvxLean/Lib/Math/LinearAlgebra/Eigenspace.lean index c155455d..89af034b 100644 --- a/CvxLean/Lib/Math/LinearAlgebra/Eigenspace.lean +++ b/CvxLean/Lib/Math/LinearAlgebra/Eigenspace.lean @@ -1,5 +1,9 @@ import Mathlib.LinearAlgebra.Eigenspace.Basic +/-! +Results about the eigenspaces of the sum of two endomorphisms. +-/ + universe u v w namespace Module @@ -17,7 +21,7 @@ lemma eigenspace_add {f g : End R M} {a b : R} : LinearMap.sub_apply, algebraMap_end_apply] at hf hg simp only [eigenspace, map_add, LinearMap.mem_ker, LinearMap.sub_apply, LinearMap.add_apply, algebraMap_end_apply] - rw [←add_sub, add_comm (a • x), ←sub_sub, hg, add_sub, add_zero, hf] + rw [← add_sub, add_comm (a • x), ← sub_sub, hg, add_sub, add_zero, hf] lemma eigenspace_one : eigenspace (1 : End R M) 1 = ⊤ := by apply eq_top_iff.2 diff --git a/CvxLean/Lib/Math/LinearAlgebra/Matrix/Spectrum.lean b/CvxLean/Lib/Math/LinearAlgebra/Matrix/Spectrum.lean index 67a09224..91fe4d97 100644 --- a/CvxLean/Lib/Math/LinearAlgebra/Matrix/Spectrum.lean +++ b/CvxLean/Lib/Math/LinearAlgebra/Matrix/Spectrum.lean @@ -1,6 +1,10 @@ import Mathlib.LinearAlgebra.Matrix.Spectrum import CvxLean.Lib.Math.Analysis.InnerProductSpace.Spectrum +/-! +Version of the spectral theorem for matrices. +-/ + namespace Matrix variable {𝕜 : Type _} [IsROrC 𝕜] [DecidableEq 𝕜] diff --git a/CvxLean/Lib/Math/LinearAlgebra/Matrix/ToUpperTri.lean b/CvxLean/Lib/Math/LinearAlgebra/Matrix/ToUpperTri.lean index 0b83f404..4bd24d42 100644 --- a/CvxLean/Lib/Math/LinearAlgebra/Matrix/ToUpperTri.lean +++ b/CvxLean/Lib/Math/LinearAlgebra/Matrix/ToUpperTri.lean @@ -1,14 +1,14 @@ import Mathlib.Data.Real.Basic import CvxLean.Lib.Math.LinearAlgebra.Matrix.Triangular -namespace Matrix - /-! # Conversion to upper triangular matrices Defines `toUpperTri` and proves some basic lemmas about it. -/ +namespace Matrix + open Matrix variable {n : Type} [Fintype n] [LinearOrder n] [LocallyFiniteOrderBot n] diff --git a/CvxLean/Lib/Math/LinearAlgebra/Matrix/Triangular.lean b/CvxLean/Lib/Math/LinearAlgebra/Matrix/Triangular.lean index 256ac6d5..a0711a5c 100644 --- a/CvxLean/Lib/Math/LinearAlgebra/Matrix/Triangular.lean +++ b/CvxLean/Lib/Math/LinearAlgebra/Matrix/Triangular.lean @@ -1,7 +1,7 @@ import Mathlib.LinearAlgebra.Matrix.Block import CvxLean.Lib.Math.LinearAlgebra.Matrix.Block -/- +/-! # Triangular Matrices This file defines upper and lower triangular matrices. The definitions are based on diff --git a/CvxLean/Lib/Math/Subadditivity.lean b/CvxLean/Lib/Math/Subadditivity.lean index 6d31a33a..b63da3c8 100644 --- a/CvxLean/Lib/Math/Subadditivity.lean +++ b/CvxLean/Lib/Math/Subadditivity.lean @@ -7,12 +7,17 @@ import CvxLean.Lib.Math.LinearAlgebra.Matrix.PosDef import CvxLean.Lib.Math.LinearAlgebra.Matrix.Spectrum import CvxLean.Lib.Math.LinearAlgebra.Eigenspace +/-! +If `A` and `B` are positive semi-definite, then `det A + det B ≤ det (A + B)`. This is needed to +define the log-det atom. +-/ + namespace Finset open BigOperators lemma one_add_prod_le_prod_one_add {n : Type _} [Fintype n] [Nonempty n] - (f : n → ℝ) (hf : ∀ i, 0 ≤ f i) : 1 + (∏ i, f i) ≤ ∏ i, (1 + f i) := by + (f : n → ℝ) (hf : ∀ i, 0 ≤ f i) : 1 + (∏ i, f i) ≤ ∏ i, (1 + f i) := by classical calc 1 + (∏ i, f i) = (∏ _a : n, 1 : ℝ) * ∏ a : n in univ \ univ, f a @@ -38,56 +43,56 @@ namespace IsHermitian variable {𝕜 : Type _} [DecidableEq 𝕜] [IsROrC 𝕜] {A : Matrix n n 𝕜} (hA : A.IsHermitian) -lemma eigenvectorMatrix_inv_mul : - hA.eigenvectorMatrixInv * hA.eigenvectorMatrix = 1 := -by apply Basis.toMatrix_mul_toMatrix_flip +lemma eigenvectorMatrix_inv_mul : hA.eigenvectorMatrixInv * hA.eigenvectorMatrix = 1 := by + apply Basis.toMatrix_mul_toMatrix_flip -- NOTE: There is a `spectral_theorem'`. theorem spectral_theorem'' : - hA.eigenvectorMatrix * diagonal (IsROrC.ofReal ∘ hA.eigenvalues) * hA.eigenvectorMatrixᴴ = A := by + hA.eigenvectorMatrix * diagonal (IsROrC.ofReal ∘ hA.eigenvalues) * hA.eigenvectorMatrixᴴ = + A := by rw [conjTranspose_eigenvectorMatrix, Matrix.mul_assoc, ← spectral_theorem, - ←Matrix.mul_assoc, eigenvectorMatrix_mul_inv, Matrix.one_mul] + ← Matrix.mul_assoc, eigenvectorMatrix_mul_inv, Matrix.one_mul] end IsHermitian noncomputable def IsHermitian.sqrt {A : Matrix n n ℝ} (hA : A.IsHermitian) : Matrix n n ℝ := -hA.eigenvectorMatrix * Matrix.diagonal (fun i => (hA.eigenvalues i).sqrt) * hA.eigenvectorMatrixᵀ + hA.eigenvectorMatrix * Matrix.diagonal (fun i => (hA.eigenvalues i).sqrt) * hA.eigenvectorMatrixᵀ lemma conjTranspose_eq_transpose {m n : Type _} {A : Matrix m n ℝ} : Aᴴ = Aᵀ := rfl @[simp] lemma PosSemidef.sqrt_mul_sqrt {A : Matrix n n ℝ} (hA : A.PosSemidef) : - hA.1.sqrt * hA.1.sqrt = A := -calc - hA.1.sqrt * hA.1.sqrt = - hA.1.eigenvectorMatrix * (Matrix.diagonal (fun i => (hA.1.eigenvalues i).sqrt) - * (hA.1.eigenvectorMatrixᵀ * hA.1.eigenvectorMatrix) - * Matrix.diagonal (fun i => (hA.1.eigenvalues i).sqrt)) * hA.1.eigenvectorMatrixᵀ := by - simp [IsHermitian.sqrt, Matrix.mul_assoc] - _ = A := by - rw [←conjTranspose_eq_transpose, hA.1.conjTranspose_eigenvectorMatrix, - hA.1.eigenvectorMatrix_inv_mul, Matrix.mul_one, diagonal_mul_diagonal, - ← hA.1.conjTranspose_eigenvectorMatrix] - convert hA.1.spectral_theorem'' - rw [←Real.sqrt_mul (hA.eigenvalues_nonneg _), Real.sqrt_mul_self (hA.eigenvalues_nonneg _)] - simp + hA.1.sqrt * hA.1.sqrt = A := + calc + hA.1.sqrt * hA.1.sqrt = + hA.1.eigenvectorMatrix * (Matrix.diagonal (fun i => (hA.1.eigenvalues i).sqrt) + * (hA.1.eigenvectorMatrixᵀ * hA.1.eigenvectorMatrix) + * Matrix.diagonal (fun i => (hA.1.eigenvalues i).sqrt)) * hA.1.eigenvectorMatrixᵀ := by + simp [IsHermitian.sqrt, Matrix.mul_assoc] + _ = A := by + rw [← conjTranspose_eq_transpose, hA.1.conjTranspose_eigenvectorMatrix, + hA.1.eigenvectorMatrix_inv_mul, Matrix.mul_one, diagonal_mul_diagonal, + ← hA.1.conjTranspose_eigenvectorMatrix] + convert hA.1.spectral_theorem'' + rw [← Real.sqrt_mul (hA.eigenvalues_nonneg _), Real.sqrt_mul_self (hA.eigenvalues_nonneg _)] + simp lemma PosSemidef.PosSemidef_sqrt {A : Matrix n n ℝ} (hA : A.PosSemidef) : - hA.1.sqrt.PosSemidef := -PosSemidef.conjTranspose_mul_mul _ _ - (PosSemidef_diagonal (fun i => Real.sqrt_nonneg (hA.1.eigenvalues i))) + hA.1.sqrt.PosSemidef := + PosSemidef.conjTranspose_mul_mul _ _ + (PosSemidef_diagonal (fun i => Real.sqrt_nonneg (hA.1.eigenvalues i))) lemma IsHermitian.one_add {A : Matrix n n ℝ} (hA : A.IsHermitian) : (1 + A).IsHermitian := by dsimp [IsHermitian]; rw [IsHermitian.add _ hA]; simp lemma IsHermitian.has_eigenvector_one_add {A : Matrix n n ℝ} (hA : A.IsHermitian) (i : n) : - Module.End.HasEigenvector (1 + Matrix.toLin' A) (1 + (hA.eigenvalues i)) ((hA.eigenvectorBasis) i) := -Module.End.has_eigenvector_add - (Module.End.has_eigenvector_one (hA.hasEigenvector_eigenvectorBasis i).2) - (hA.hasEigenvector_eigenvectorBasis i) + Module.End.HasEigenvector + (1 + Matrix.toLin' A) (1 + (hA.eigenvalues i)) ((hA.eigenvectorBasis) i) := + Module.End.has_eigenvector_add + (Module.End.has_eigenvector_one (hA.hasEigenvector_eigenvectorBasis i).2) + (hA.hasEigenvector_eigenvectorBasis i) -lemma PosDef.PosDef_sqrt {A : Matrix n n ℝ} (hA : A.PosDef) : - hA.1.sqrt.PosDef := by +lemma PosDef.PosDef_sqrt {A : Matrix n n ℝ} (hA : A.PosDef) : hA.1.sqrt.PosDef := by unfold IsHermitian.sqrt refine' PosDef.conjTranspose_mul_mul _ (hA.1.eigenvectorMatrixᵀ) @@ -97,11 +102,11 @@ lemma PosDef.PosDef_sqrt {A : Matrix n n ℝ} (hA : A.PosDef) : apply det_ne_zero_of_right_inverse hA.1.eigenvectorMatrix_mul_inv lemma PosSemidef.PosDef_iff_det_ne_zero [DecidableEq n] {M : Matrix n n ℝ} (hM : M.PosSemidef) : - M.PosDef ↔ M.det ≠ 0 := by + M.PosDef ↔ M.det ≠ 0 := by refine' ⟨PosDef.det_ne_zero, _⟩; intro hdet; refine' ⟨hM.1, _⟩ intros x hx apply lt_of_le_of_ne' (hM.2 x) - rw [←hM.sqrt_mul_sqrt, ←mulVec_mulVec, dotProduct_mulVec, ←transpose_transpose hM.1.sqrt, + rw [← hM.sqrt_mul_sqrt, ← mulVec_mulVec, dotProduct_mulVec, ← transpose_transpose hM.1.sqrt, vecMul_transpose, transpose_transpose, ← conjTranspose_eq_transpose, hM.PosSemidef_sqrt.1.eq] simp only [IsROrC.re_to_real, star, id] @@ -110,7 +115,7 @@ lemma PosSemidef.PosDef_iff_det_ne_zero [DecidableEq n] {M : Matrix n n ℝ} (hM have sqrtMdet0 : hM.1.sqrt.det = 0 := by refine' exists_mulVec_eq_zero_iff.1 ⟨x, hx, _⟩ rw [inner_self_eq_zero.1 hinner] - rw [←hM.sqrt_mul_sqrt, det_mul, sqrtMdet0, mul_zero] at hdet + rw [← hM.sqrt_mul_sqrt, det_mul, sqrtMdet0, mul_zero] at hdet apply hdet rfl /-- Subadditivity lemma for positive semidefinite matrices. This version assumes that one of the @@ -118,9 +123,8 @@ matrices is positive definite. See `det_add_det_le_det_add` for the more general The argument is taken from Andreas Thom's comment on mathoverflow: https://mathoverflow.net/questions/65424/determinant-of-sum-of-positive-definite-matrices. -/ -lemma det_add_det_le_det_add' [Nonempty n] (A B : Matrix n n ℝ) - (hA : A.PosDef) (hB : B.PosSemidef) : - A.det + B.det ≤ (A + B).det := by +lemma det_add_det_le_det_add' [Nonempty n] (A B : Matrix n n ℝ) (hA : A.PosDef) + (hB : B.PosSemidef) : A.det + B.det ≤ (A + B).det := by let sqrtA := hA.1.sqrt have isUnit_det_sqrtA := isUnit_iff_ne_zero.2 hA.PosDef_sqrt.det_ne_zero @@ -134,8 +138,8 @@ lemma det_add_det_le_det_add' [Nonempty n] (A B : Matrix n n ℝ) let μ := PosSemidef_ABA.1.eigenvalues calc A.det + B.det = A.det * (1 + (sqrtA⁻¹ * B * sqrtA⁻¹).det) := by - rw [det_mul, det_mul, mul_comm _ B.det, mul_assoc, ←det_mul, ←Matrix.mul_inv_rev, - hA.posSemidef.sqrt_mul_sqrt, mul_add, mul_one, mul_comm, mul_assoc, ←det_mul, + rw [det_mul, det_mul, mul_comm _ B.det, mul_assoc, ← det_mul, ← Matrix.mul_inv_rev, + hA.posSemidef.sqrt_mul_sqrt, mul_add, mul_one, mul_comm, mul_assoc, ← det_mul, nonsing_inv_mul _ (isUnit_iff_ne_zero.2 hA.det_ne_zero), det_one, mul_one] _ = A.det * (1 + ∏ i, μ i) := by rw [PosSemidef_ABA.1.det_eq_prod_eigenvalues] @@ -155,7 +159,7 @@ lemma det_add_det_le_det_add' [Nonempty n] (A B : Matrix n n ℝ) _ = (A + B).det := by rw [← det_mul, ← det_conj this (A + B)] apply congr_arg - rw [←hA.posSemidef.sqrt_mul_sqrt] + rw [← hA.posSemidef.sqrt_mul_sqrt] change sqrtA * sqrtA * (1 + sqrtA⁻¹ * B * sqrtA⁻¹) = sqrtA * (sqrtA * sqrtA + B) * sqrtA⁻¹ rw [Matrix.mul_add, Matrix.mul_one, Matrix.mul_add, Matrix.add_mul, Matrix.mul_assoc, Matrix.mul_assoc, Matrix.mul_assoc, Matrix.mul_assoc, @@ -164,9 +168,8 @@ lemma det_add_det_le_det_add' [Nonempty n] (A B : Matrix n n ℝ) hA.posSemidef.sqrt_mul_sqrt, Matrix.mul_assoc] /-- Subadditivity lemma for positive semidefinite matrices. -/ -lemma det_add_det_le_det_add [Nonempty n] (A B : Matrix n n ℝ) - (hA : A.PosSemidef) (hB : B.PosSemidef) : - A.det + B.det ≤ (A + B).det := by +lemma det_add_det_le_det_add [Nonempty n] (A B : Matrix n n ℝ) (hA : A.PosSemidef) + (hB : B.PosSemidef) : A.det + B.det ≤ (A + B).det := by by_cases hA' : A.det = 0 { by_cases hB' : B.det = 0 { simp [hA', hB'] diff --git a/CvxLean/Meta/Minimization.lean b/CvxLean/Meta/Minimization.lean index e4fd0173..d1f6e67d 100644 --- a/CvxLean/Meta/Minimization.lean +++ b/CvxLean/Meta/Minimization.lean @@ -135,13 +135,14 @@ names of the new variables. -/ def withDomainLocalDecls {α} [Inhabited α] (domain : Expr) (p : Expr) (x : Array Expr → Array Expr → m α) : m α := do let pr := (← mkProjections domain p).toArray - withLetDecls (pr.map fun (n, ty, val) => (n, fun _ => return ty, fun _ => return val)) fun xs => do - let mut xs := xs - -- Use projections instead of variables named "_". - for i in [:pr.size] do - if pr[i]!.1 == `_ then - xs := xs.set! i pr[i]!.2.2 - x xs (pr.map (fun a => a.2.2)) + withLetDecls + (pr.map fun (n, ty, val) => (n, fun _ => return ty, fun _ => return val)) fun xs => do + let mut xs := xs + -- Use projections instead of variables named "_". + for i in [:pr.size] do + if pr[i]!.1 == `_ then + xs := xs.set! i pr[i]!.2.2 + x xs (pr.map (fun a => a.2.2)) /-- Decompose an expression into its `And`-connected components. -/ def decomposeAnd (e : Expr) : MetaM (List (Expr)) := do diff --git a/CvxLean/Meta/Util/Error.lean b/CvxLean/Meta/Util/Error.lean index c8a55a75..e7167c30 100644 --- a/CvxLean/Meta/Util/Error.lean +++ b/CvxLean/Meta/Util/Error.lean @@ -139,6 +139,14 @@ end PreDCP section DCP +/-- Throws an error coming from the `declare_atom` command. -/ +syntax "throwAtomDeclarationError " (interpolatedStr(term) <|> term) : term + +macro_rules + | `(throwAtomDeclarationError $msg:interpolatedStr) => + `(throwError ("`declare_atom` error: " ++ (m! $msg))) + | `(throwAtomDeclarationError $msg:term) => `(throwError ("`declare_atom` error: " ++ $msg)) + /-- Throws an error coming from the `dcp` tactic or any of the auxiliary functions. -/ syntax "throwDCPError " (interpolatedStr(term) <|> term) : term diff --git a/CvxLean/Meta/Util/Expr.lean b/CvxLean/Meta/Util/Expr.lean index 40c3d9f8..758aa7db 100644 --- a/CvxLean/Meta/Util/Expr.lean +++ b/CvxLean/Meta/Util/Expr.lean @@ -19,6 +19,25 @@ def mkAndIntro (xs : Array Expr) : MetaM Expr := do res ← mkAppM ``And.intro #[xs[xs.size - 2 - i]!, res] return res +/-- Similar to `mkAndIntro` but supporting empty list. +TODO: combine the two. -/ +def foldAndIntro (exprs : Array Expr) : MetaM Expr := do + if exprs.size == 0 then + return mkConst ``True.intro + let mut res := exprs[exprs.size - 1]! + for i in [1:exprs.size] do + res ← mkAppM ``And.intro #[exprs[exprs.size - 1 - i]!, res] + return res + +/-- Turn an array of expressions into a product of expressions. -/ +def foldProdMk (exprs : Array Expr) : MetaM Expr := do + if exprs.size == 0 then + throwError "cannot make a product of an empty list." + let mut res := exprs[exprs.size - 1]! + for i in [1:exprs.size] do + res ← mkAppM ``Prod.mk #[exprs[exprs.size - 1 - i]!, res] + return res + /-- Make existential type. -/ def mkExistsFVars (xs : Array Expr) (e : Expr) : MetaM Expr := do let mut res := e diff --git a/CvxLean/Meta/Util/Meta.lean b/CvxLean/Meta/Util/Meta.lean index d1a787aa..27babf53 100644 --- a/CvxLean/Meta/Util/Meta.lean +++ b/CvxLean/Meta/Util/Meta.lean @@ -22,7 +22,7 @@ def withLambdaBody (e : Expr) (x : (fvar : Expr) → (body : Expr) → MetaM α) /-- Add local declarations where the type constructor is trivial (non-dependant on the other declarations). -/ def withLocalDeclsDNondep [Inhabited α] (declInfos : Array (Lean.Name × Expr)) - (k : (xs : Array Expr) → m α) : m α := + (k : (xs : Array Expr) → m α) : m α := withLocalDeclsD (declInfos.map fun (n, t) => (n, fun _ => pure t)) k /-- Introduce let declarations into the local context. -/ @@ -34,7 +34,7 @@ where loop [Inhabited α] (acc : Array Expr) : m α := do if acc.size < declInfos.size then let (name, typeCtor, valCtor) := declInfos[acc.size]! - withLetDecl name (←typeCtor acc) (←valCtor acc) fun x => loop (acc.push x) + withLetDecl name (← typeCtor acc) (← valCtor acc) fun x => loop (acc.push x) else k acc diff --git a/CvxLean/Syntax/Minimization.lean b/CvxLean/Syntax/Minimization.lean index f25ba123..d5b5db07 100644 --- a/CvxLean/Syntax/Minimization.lean +++ b/CvxLean/Syntax/Minimization.lean @@ -50,7 +50,7 @@ open Lean.Elab Lean.Elab.Term Lean.Meta Lean.Parser.Term private def decomposeBracketedBinder : Syntax → TermElabM (Array (Syntax × Syntax)) := fun stx => match stx[0] with | `(bracketedBinderF|($ids* : $ty)) => return ids.map (·.raw, ty.raw) - | `(ident|$id) => return #[(id.raw, (←`(_)).raw)] + | `(ident|$id) => return #[(id.raw, (← `(_)).raw)] /-- Get the names and types of the variables after the `optimization` keyword. -/ def elabVars (idents : Array (TSyntax `CvxLean.Parser.minimizationVar)) : diff --git a/CvxLean/Tactic/Arith/PositivityExt.lean b/CvxLean/Tactic/Arith/PositivityExt.lean index 668a18eb..c2678f88 100644 --- a/CvxLean/Tactic/Arith/PositivityExt.lean +++ b/CvxLean/Tactic/Arith/PositivityExt.lean @@ -70,8 +70,7 @@ def evalOneDivSqSubOne : PositivityExt where eval {_ _α} zα pα e := do | _, _ => pure .none -lemma Real.one_sub_sq_nonneg_of_le_one {x : ℝ} : - 0 ≤ x → 0 ≤ 1 - x → 0 ≤ 1 - x ^ (2 : ℝ) := +lemma Real.one_sub_sq_nonneg_of_le_one {x : ℝ} : 0 ≤ x → 0 ≤ 1 - x → 0 ≤ 1 - x ^ (2 : ℝ) := fun h1 h2 => by have hx1 : x ≤ 1 := by linarith rw [le_sub_iff_add_le, zero_add] @@ -119,8 +118,7 @@ def evalExpSubOne : PositivityExt where eval {_ _α} zα pα e := do | _ => pure .none -lemma Real.one_sub_div_exp_pos_of_pos {x : ℝ} : - 0 < x → 0 < 1 - 1 / Real.exp x := +lemma Real.one_sub_div_exp_pos_of_pos {x : ℝ} : 0 < x → 0 < 1 - 1 / Real.exp x := fun h => by field_simp; positivity @[positivity (1 - (1 / (Real.exp (_ : ℝ))))] @@ -135,13 +133,13 @@ def evalOneSubDivExp : PositivityExt where eval {_ _α} zα pα e := do pure .none lemma Real.scaled_sq_diff_pos_of_pos {a x : ℝ} : - 0 < a - 1 → 0 < x → 0 < (a * x) ^ (2 : ℝ) - x ^ (2 : ℝ) := + 0 < a - 1 → 0 < x → 0 < (a * x) ^ (2 : ℝ) - x ^ (2 : ℝ) := fun h1 h2 => by have ha : 0 ≤ a := by linarith have ha1 : 1 < a := by linarith have hx : 0 ≤ x := by linarith have hx2 : 0 < x ^ (2 : ℝ) := by positivity - rw [sub_pos, Real.mul_rpow ha hx, ←div_lt_iff hx2, div_self (ne_of_gt hx2)] + rw [sub_pos, Real.mul_rpow ha hx, ← div_lt_iff hx2, div_self (ne_of_gt hx2)] simp [ha1, abs_eq_self.mpr ha] @[positivity (((_ : ℝ) * (_ : ℝ)) ^ (2 : ℝ)) - ((_ : ℝ) ^ (2 : ℝ))] diff --git a/CvxLean/Tactic/Basic/ChangeOfVariables.lean b/CvxLean/Tactic/Basic/ChangeOfVariables.lean index 88c3d55d..b95b5057 100644 --- a/CvxLean/Tactic/Basic/ChangeOfVariables.lean +++ b/CvxLean/Tactic/Basic/ChangeOfVariables.lean @@ -3,6 +3,7 @@ import CvxLean.Lib.Minimization import CvxLean.Lib.Equivalence import CvxLean.Meta.Minimization import CvxLean.Meta.Util.Expr +import CvxLean.Meta.Util.Error import CvxLean.Meta.TacticBuilder import CvxLean.Tactic.Arith.Arith import CvxLean.Tactic.Basic.RemoveTrivialConstrs @@ -169,7 +170,7 @@ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) -- Find change of variables location. let covIdx := vars.findIdx? (fun ⟨n, _⟩ => n == varToChange) if covIdx.isNone then - throwError "Variable {varToChange} not found in domain." + throwChangeOfVariablesError "variable {varToChange} not found in domain." let covIdx := covIdx.get! -- New domain. @@ -217,9 +218,9 @@ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) let toApply ← instantiateMVars toApply let gsAfterApply ← g.apply toApply if gsAfterApply.length != 1 then - throwError ( - "Failed to apply `ChangeOfVariables.toEquivalence`. " ++ - "Make sure that the change of variables is inferrable by type class resolution.") + throwChangeOfVariablesError ( + "failed to apply `ChangeOfVariables.toEquivalence`, " ++ + "make sure that the change of variables is inferrable by type class resolution.") -- Solve change of variables condition. let gCondition := gsAfterApply[0]! @@ -228,7 +229,7 @@ def changeOfVariablesBuilder (newVarStx varToChangeStx : TSyntax `ident) (← `(tactic| (simp [ChangeOfVariables.condition] at * <;> arith))) gCondition if gsFinal.length != 0 then trace[CvxLean.debug] "Could not prove {gsFinal}." - throwError "Failed to solve change of variables condition." + throwChangeOfVariablesError "failed to solve change of variables condition." end Meta diff --git a/CvxLean/Tactic/Basic/RewriteOpt.lean b/CvxLean/Tactic/Basic/RewriteOpt.lean index 9fa37d1e..09277963 100644 --- a/CvxLean/Tactic/Basic/RewriteOpt.lean +++ b/CvxLean/Tactic/Basic/RewriteOpt.lean @@ -9,11 +9,39 @@ import CvxLean.Tactic.Basic.RenameConstrs /-! # Tactics for conditional rewrites -TODO +This file defines the `rw_opt`, `rw_obj`, and `rw_constr` tactics. These tactics are used to modify +a particular component, usually by rewriting it. There is also the option of using `into` to +specify the target expression exactly, e.g., `rw_constr c₁ into (0 ≤ x + y)`. + +From the point of view of the user, this is similar `CvxLean/Tactic/Basic/ConvOpt.lean` in that this +tactic combinator allows us to focus on the objective function or one constraint, and apply tactics +to it. The difference is that the local context contains the constraints (other than the current one +if a constraint is being modified). This is crucial to apply rewrites that depend on conditions +established by other constraints. + +For example, consider the following: +``` +equivalence eqv/q : + optimization (x : ℝ) + minimize sqrt (x ^ 2) + subject to + c₁ : 0 ≤ x := by + rw_obj => + -- x : ℝ, c₁ : 0 ≤ x ⊢ sqrt (x ^ 2) = ?g x + rw [rpow_two, sqrt_sq c₁] +``` +In this case, we need `c₁` to be able to rewrite `sqrt (x ^ 2) = x`. The resulting problem `q` is: +``` +optimization (x : ℝ) + minimize x + subject to + c₁ : 0 ≤ x +``` ### TODO -* Clean up. +* Using the lemmas of the form `rewrite_constraint_i` is ugly because we need many of them. It does + improve performance compared to building the lemma every time. A better approach is needed here. -/ namespace CvxLean @@ -22,6 +50,8 @@ open Lean Meta Elab Parser Tactic namespace Meta +/-- Find appropriate lemma to rewrite the objective function depending on the number of +constraints, which makes it easy to add them to the local context by unification. -/ def rewriteObjLemma (numConstrs : Nat) : MetaM Name := match numConstrs with | 1 => return ``Minimization.Equivalence.rewrite_objFun_1 @@ -36,7 +66,10 @@ def rewriteObjLemma (numConstrs : Nat) : MetaM Name := | 10 => return ``Minimization.Equivalence.rewrite_objFun_10 | _ => throwRwObjError "can only rewrite problems with up to 10 constraints." -/-- -/ +/-- This is the main equivalence builder to rewrite the objective function. It takes a tactic that +can use both the current goal and the variable representing the domain of the problem. The latter +is used in `pre_dcp` to easily build expressions. It takes care of applying the correct "congruence" +lemma to set up the local context to rewrite the objective function, and applies the tactic. -/ def rewriteObjBuilderFromTactic (shouldEval : Bool) (tac : MVarId → FVarId → TacticM Unit) (rhs? : Option Expr) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do @@ -67,7 +100,7 @@ def rewriteObjBuilderFromTactic (shouldEval : Bool) (tac : MVarId → FVarId → trace[CvxLean.debug] "`rw_obj` could not close {gs}." throwRwObjError "could not close all goals." -/-- -/ +/-- Wrapper of `rewriteObjBuilderFromTactic` that takes syntax instead. -/ def rewriteObjBuilder (shouldEval : Bool) (tacStx : Syntax) (rhs? : Option Expr) : EquivalenceBuilder Unit := fun eqvExpr g => do @@ -76,10 +109,13 @@ def rewriteObjBuilder (shouldEval : Bool) (tacStx : Syntax) (rhs? : Option Expr) let tac := fun g _ => evalTacticAt tacStx g >>= setGoals rewriteObjBuilderFromTactic shouldEval tac rhs? eqvExpr g +/-- Same as `rewriteObjBuilder` but without a target RHS. -/ def rewriteObjBuilderWithoutTarget (shouldEval : Bool) (tacStx : Syntax) : EquivalenceBuilder Unit := rewriteObjBuilder shouldEval tacStx none +/-- Same as `rewriteObjBuilder` but with a target RHS, given as syntax. Therefore, we need to do +some work to create an expression over the optimization variables. -/ def rewriteObjBuilderWithTarget (shouldEval : Bool) (tacStx : Syntax) (targetStx : Syntax) : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do @@ -92,7 +128,9 @@ def rewriteObjBuilderWithTarget (shouldEval : Bool) (tacStx : Syntax) (targetStx mkLambdaFVars #[p] ((target.replaceFVars fvars xs).replaceFVars xs prs) rewriteObjBuilder shouldEval tacStx (some target) eqvExpr g -/-- Returns lemma to rewrite constraints at `rwIdx` and the name of the RHS parameter. -/ +/-- Returns lemma to rewrite constraints at `rwIdx` and the name of the RHS parameter. Similarly to +`rewriteObjLemma`, it depends on the number of constraints to appropriately introduce the +constraints that the rewrite might depend on. -/ def rewriteConstrLemma (rwIdx : Nat) (numConstrs : Nat) : MetaM (Name × Name) := if rwIdx == numConstrs then match rwIdx with @@ -128,6 +166,7 @@ open Lean.Parser.Category Std.Tactic.RCases def nameToRintroPat (n : Name) : TSyntax `rcasesPat := Unhygienic.run `(rcasesPat| $(Lean.mkIdent n):ident) +/-- Trick to turn a list of names into a `rintro` pattern. -/ def namesToRintroPat (names : List Name) : MetaM (TSyntax `rcasesPat) := do let ns := names.map nameToRintroPat match ns with @@ -150,7 +189,9 @@ end RIntro open Term -/-- -/ +/-- This is the main equivalence builder to rewrite a constraint conditionally. It operates very +similarly to `rewriteObjBuilderFromTactic`. Some furhter work is needed here since we need to +carefully put every *other* constraint in the local context. -/ def rewriteConstrBuilderFromTactic (shouldEval : Bool) (constrTag : Name) (tac : MVarId → FVarId → TacticM Unit) (rhs? : Option Expr) : EquivalenceBuilder (List Name) := fun eqvExpr g => g.withContext do @@ -205,6 +246,7 @@ def rewriteConstrBuilderFromTactic (shouldEval : Bool) (constrTag : Name) return constrTags +/-- Wrapper of `rewriteConstrBuilderFromTactic` that takes syntax instead. -/ def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) (rhs? : Option Expr) : EquivalenceBuilder (List Name) := fun g eqvExpr => do @@ -213,10 +255,13 @@ def rewriteConstrBuilder (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax let tac := fun g _ => evalTacticAt tacStx g >>= setGoals rewriteConstrBuilderFromTactic shouldEval constrTag tac rhs? g eqvExpr +/-- Same as `rewriteConstrBuilder` but without a target RHS. -/ def rewriteConstrBuilderWithoutTarget (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) : EquivalenceBuilder (List Name) := rewriteConstrBuilder shouldEval constrTag tacStx none +/-- Same as `rewriteConstrBuilder` but with a target RHS, given as syntax. Therefore, we need to do +some work to create an expression over the optimization variables. -/ def rewriteConstrBuilderWithTarget (shouldEval : Bool) (constrTag : Name) (tacStx : Syntax) (targetStx : Syntax) : EquivalenceBuilder (List Name) := fun eqvExpr g => g.withContext do @@ -233,6 +278,11 @@ end Meta namespace Tactic +/-- This tactic can be used with or without a target expression as follows: `rw_obj => ...` and +`rw_obj into rhs => ...`, where `rhs` is a term over the optimization variables. The `...` might be +replaced by any tactic block, usually rewrites. The result is an equivalent optimization problem +with the objective function modified as specified by the tactic block. The rewrites might be +conditional on the constraints. -/ syntax (name := rwObj) "rw_obj" (&" into" term)? darrow (tacticSeq)? : tactic @[tactic rwObj] @@ -247,6 +297,11 @@ def evalRwObjFun : Tactic := fun stx => match stx with (rewriteObjBuilderWithTarget true tacStx targetStx).toTactic | _ => throwUnsupportedSyntax +/-- This tactic can be used with or without a target expression as follows: `rw_constr c₁ => ...` +and `rw_constr c₁ into rhs => ...`, where `rhs` is a term over the optimization variables. The `...` +might be replaced by any tactic block, usually rewrites. The result is an equivalent optimization +problem with the constraint `c₁` modified as specified by the tactic block. The rewrites might be +conditional on the other constraints, e.g., `c₂`, ..., `cₙ`. -/ syntax (name := rwConstr) "rw_constr " (ident)? (&" into" term)? darrow (tacticSeq)? : tactic @[tactic rwConstr] diff --git a/CvxLean/Tactic/Basic/ShowVars.lean b/CvxLean/Tactic/Basic/ShowVars.lean index b11c009d..00d4e0b7 100644 --- a/CvxLean/Tactic/Basic/ShowVars.lean +++ b/CvxLean/Tactic/Basic/ShowVars.lean @@ -1,14 +1,18 @@ import CvxLean.Meta.Minimization -namespace CvxLean +/-! +This tactic should ideally not be used. In some cases, it could happen that the nice problem syntax +with named optimization variables is lost (e.g., simplifying at the wrong place). The `show_vars` +tactic can be used to recover the names. +-/ -open Lean +namespace CvxLean -open Meta Elab Tactic +open Lean Meta Elab Tactic namespace Meta -/-- -/ +/-- Essentially replace projections by the appropriate optimization variable names. -/ def showVars (goal : MVarId) (p : FVarId) : MetaM MVarId := do -- Revert variables after p. let (revertedVars, goal) ← MVarId.withContext goal do @@ -28,7 +32,7 @@ def showVars (goal : MVarId) (p : FVarId) : MetaM MVarId := do let mVar ← mkFreshExprMVar newTarget MetavarKind.syntheticOpaque MVarId.assign goal (← mkLetFVars xs mVar) return mVar.mvarId! - -- Reintroduce reverted variables. + -- Re-introduce reverted variables. let (_, goal) ← MVarId.introNP goal revertedVars.size return goal @@ -36,27 +40,27 @@ end Meta namespace Tactic +/-- Show the names of the optimization variables (useful if they have been lost for some reason). -/ syntax (name := showVars) "show_vars" ident : tactic -/-- -/ @[tactic showVars] -partial def evalShowVars : Tactic -| `(tactic| show_vars $p) => do - replaceMainGoal [← Meta.showVars (← getMainGoal) (← withMainContext do getFVarId p)] -| _ => throwUnsupportedSyntax +def evalShowVars : Tactic + | `(tactic| show_vars $p) => do + replaceMainGoal [← Meta.showVars (← getMainGoal) (← withMainContext do getFVarId p)] + | _ => throwUnsupportedSyntax end Tactic namespace Tactic.Conv +/-- Like `show_vars` but as a conv tactic. -/ syntax (name := showVars) "show_vars" ident : conv -/-- Like the basic `show_vars` but as a conv tactic. -/ @[tactic showVars] -partial def evalShowVars : Tactic -| `(conv| show_vars $p) => do - replaceMainGoal [← Meta.showVars (← getMainGoal) (← withMainContext do getFVarId p)] -| _ => throwUnsupportedSyntax +def evalShowVars : Tactic + | `(conv| show_vars $p) => do + replaceMainGoal [← Meta.showVars (← getMainGoal) (← withMainContext do getFVarId p)] + | _ => throwUnsupportedSyntax end Tactic.Conv diff --git a/CvxLean/Tactic/DCP/AtomCmd.lean b/CvxLean/Tactic/DCP/AtomCmd.lean index 7a099462..b396015b 100644 --- a/CvxLean/Tactic/DCP/AtomCmd.lean +++ b/CvxLean/Tactic/DCP/AtomCmd.lean @@ -1,829 +1,253 @@ import CvxLean.Meta.Util.Expr +import CvxLean.Meta.Util.Error import CvxLean.Meta.Minimization import CvxLean.Tactic.DCP.AtomExt import CvxLean.Tactic.DCP.AtomSyntax -import CvxLean.Lib.Math.Data.Real -import CvxLean.Tactic.DCP.DCP +import CvxLean.Tactic.DCP.AtomCmdElab +import CvxLean.Tactic.DCP.AtomCmdMultiLevel /-! # Atom library (command) -TODO - -### TODO - -* Make `vconds` available in `impConstrs` and `impObj`, in the same way that `bconds` are in - `elabImpConstrs`. This needs to be synchronised with the DCP procedure. +This file defines all the versions of the `declare_atom` command. First, we allow atoms with or +without background conditions. Affine atoms are treated separately. Note that they do not really +have a graph implementation, they are already in conic form. We add to extra properties to make sure +that they are indeed affine, which are used only as a sanity check. In our framework, cones are also +atoms, which are again trivial, i.e., they are their own implementation. Thus, we also have a +special command for cones. -/ namespace CvxLean open Lean Expr Meta Elab Command --- TODO: This does not respect namespaces. -/-- Introduce new names for the proofs in the atom to speed up proof building later. -/ -def addAtomDataDecls (id : Lean.Name) (atomData : GraphAtomData) : CommandElabM GraphAtomData := do - if atomData.solEqAtom.hasMVar then - throwError "has mvar {atomData.solEqAtom}" - let solEqAtom ← addAtomDataDecl (id.mkStr "solEqAtom") atomData.solEqAtom - let optimality ← addAtomDataDecl (id.mkStr "optimality") atomData.optimality - let mut feasibility := #[] - for i in [:atomData.feasibility.size] do - feasibility := feasibility.push $ - ← addAtomDataDecl (id.mkStr (s!"feasibility{i}")) atomData.feasibility[i]! - let mut vcondElim := #[] - for i in [:atomData.vcondElim.size] do - vcondElim := vcondElim.push $ - ← addAtomDataDecl (id.mkStr (s!"vcondElim{i}")) atomData.vcondElim[i]! - return {atomData with - solEqAtom := solEqAtom - feasibility := feasibility - vcondElim := vcondElim - optimality := optimality - } -where - addAtomDataDecl (name : Lean.Name) (expr : Expr) : CommandElabM Expr := do - liftCoreM <| addDecl <| .defnDecl { - name := name - levelParams := [] - type := ← liftTermElabM (do return ← inferType expr) - value := expr - hints := .regular (getMaxHeight (← getEnv) expr + 1) - safety := .safe - } - return mkConst name - - -/-- -/ -def withCopyOfMonoXs {α} [Inhabited α] (xs : Array Expr) (argKinds : Array ArgKind) (f : Array Expr → Array Expr → Array ArgKind → TermElabM α) : TermElabM α := do - -- Determine subset of monotone arguments and their behavior. - let mut argDeclInfo : Array (Lean.Name × (Array Expr → TermElabM Expr)) := #[] - let mut monoXs := #[] - let mut monoArgKind := #[] - for i in [:xs.size] do - if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then - let ty := ← inferType xs[i]! - let name := (← FVarId.getDecl xs[i]!.fvarId!).userName - argDeclInfo := argDeclInfo.push (name, fun _ => pure ty) - monoXs := monoXs.push xs[i]! - monoArgKind := monoArgKind.push argKinds[i]! - - withLocalDeclsD argDeclInfo fun ys => do - return ← f monoXs ys monoArgKind - -/-- -/ -def shiftingArgs (curv : Curvature) (xs : Array Expr) (argKinds : Array ArgKind) (mkConcl : Array Expr → Array Expr → TermElabM Expr) : TermElabM Expr := - withCopyOfMonoXs xs argKinds fun monoXs ys monoArgKind => do - let mut ty := ← mkConcl monoXs ys - for i' in [:monoXs.size] do - let i := monoXs.size - 1 - i' - ty ← match curvatureInArg curv monoArgKind[i]! with - | Curvature.Concave => mkArrow (← mkLe monoXs[i]! ys[i]!) ty - | Curvature.Convex => mkArrow (← mkLe ys[i]! monoXs[i]!) ty - | _ => throwError "invalid curvature" - return ← mkForallFVars ys ty - -/-- TODO: just the number of relevant opt args. -/ -def getMonoArgsCount (curv : Curvature) (argKinds : Array ArgKind) : ℕ := - argKinds.foldl (fun acc argKind => - match curvatureInArg curv argKind with - | Curvature.Concave => 1 + acc - | Curvature.Convex => 1 + acc - | _ => acc) 0 - -/-- Use the DCP procedure to canon the graph implementation to a problem that -uses only cone constraints and affine atoms. -/ -def canonAtomData (objCurv : Curvature) (atomData : GraphAtomData) : CommandElabM GraphAtomData := do - liftTermElabM do - -- `xs` are the arguments of the atom. - lambdaTelescope atomData.expr fun xs atomExpr => do - let atomRange ← inferType atomExpr - - -- Call DCP on graph implementation. - let (objFun, constraints, originalVarsDecls) ← do - let impVars := atomData.impVars.map fun (n, ty) => (n, mkAppNBeta ty xs) - withLocalDeclsDNondep impVars fun vs => do - let vsDecls ← vs.mapM fun v => v.fvarId!.getDecl - let bconds := atomData.bconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun bs => do - let originalVarsDecls := vsDecls - let objFun := mkAppNBeta (mkAppNBeta atomData.impObjFun xs) vs - trace[Meta.debug] "{xs} {bs} {vs}" - let constraints := atomData.impConstrs.map - fun c => (`_, mkAppNBeta (mkAppNBeta (mkAppNBeta c xs) bs) vs) - trace[Meta.debug] "constraints in with: {constraints}" - return (objFun, constraints, originalVarsDecls) - - let xsVarsPre := xs.map fun x => x.fvarId! - let mut xsVars := #[] - for i in [:xs.size] do - if atomData.argKinds[i]! != ArgKind.Constant then - xsVars := xsVars.push xsVarsPre[i]! - trace[Meta.debug] "xsVars: {(← xsVars.mapM (·.getDecl)).map (·.userName)}" - - trace[Meta.debug] "before PAT " - trace[Meta.debug] "abo : {atomData.impConstrs}" - trace[Meta.debug] "constrs: {constraints}" - -- Add bconds declarations. - let bconds := atomData.bconds.map fun (n,c) => (n, mkAppNBeta c xs) - let pat ← withLocalDeclsDNondep bconds fun bs => do - let bcondsDecls ← bs.mapM (·.fvarId!.getDecl) - DCP.mkProcessedAtomTree objCurv objFun constraints.toList originalVarsDecls - (extraVars := xsVars) (extraDecls := bcondsDecls) - - trace[Meta.debug] "after PAT " - -- `pat` is the atom tree resulting from the DCP procedure. - - -- Temporary check until using atoms in graph implementations is supported. - -- if pat.newVarDecls.length ≠ 0 ∨ pat.newConstrs.size ≠ 0 then - -- throwError "Using nontrivial atoms in graph implementations is not yet supported" - - trace[Meta.debug] "pat.originalVarsDecls : {pat.originalVarsDecls.map (·.userName)}" - trace[Meta.debug] "pat.newVarDecls : {pat.newVarDecls.map (·.userName)}" - trace[Meta.debug] "pat.newConstrVarsArray : {pat.newConstrVarsArray.map (·.userName)}" - - withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls pat.newVarDecls do - withExistingLocalDecls pat.newConstrVarsArray.toList do - trace[Meta.debug] "pat opt: {pat.optimality}" - for c in pat.optimality.constr.map Tree.val do - trace[Meta.debug] "pat opt constr: {c}" - check c - -- `vs1` are the variables already present in the uncanon graph implementation. - let vs1 := originalVarsDecls.map (mkFVar ·.fvarId) - -- `vs2` are the variables to be added to the graph implementation. - let vs2 := pat.newVarDecls.toArray.map (mkFVar ·.fvarId) - let vs1Sol := atomData.solution.map fun s => mkAppNBeta s xs - trace[Meta.debug] "vs1: {vs1}" - trace[Meta.debug] "vs2: {vs2}" - trace[Meta.debug] "vs1Sol: {vs1Sol}" - trace[Meta.debug] "xs: {xs}" - - -- TODO: move because copied from DCP tactic. - let canonConstrs := pat.canonExprs.constr.map Tree.val - let canonConstrs := canonConstrs.filterIdx (fun i => ¬ pat.isVCond[i]!) - - -- TODO: move because copied from DCP tactic. - let newConstrProofs := pat.feasibility.fold #[] fun acc fs => - fs.fold acc Array.append - - -- for sc in solEqAtomProofs do - -- trace[Meta.debug] "solEqAtomProofs: {← inferType sc}" - - for ncp in newConstrProofs do - trace[Meta.debug] "newConstrProofs: {← inferType ncp}" - - -- let test := ← atomData.feasibility.mapM (fun e => - -- Meta.forallTelescope e (fun xs a => do - -- trace[Meta.debug] "a: {←inferType a}")) - - let vconds := atomData.vconds.map fun (n,c) => (n, mkAppNBeta c xs) - let bconds := atomData.bconds.map fun (n,c) => (n, mkAppNBeta c xs) - - let solEqAtomProofs := pat.solEqAtom.constr.map Tree.val - - if atomData.feasibility.size != solEqAtomProofs.size then - throwError ("Expected same length: {atomData.feasibility} and " ++ - "{solEqAtomProofs}") - - let solEqAtomProofs := pat.solEqAtom.constr.map Tree.val - let mut oldFeasibilityAdjusted := #[] - - for i in [:atomData.feasibility.size] do - let feas := atomData.feasibility[i]! - let feasXs := mkAppNBeta feas xs - let adjustedFeas ← - withLocalDeclsDNondep vconds fun cs => do - withLocalDeclsDNondep bconds fun bs => do - let feasXsVconds := mkAppNBeta feasXs cs - let feasXsConds := mkAppNBeta feasXsVconds bs - let proofAdjusted := solEqAtomProofs[i]!.replaceFVars vs1 vs1Sol - let adjustedFeas ← mkAppM ``Eq.mpr #[proofAdjusted, feasXsConds] - mkLambdaFVars (xs ++ cs ++ bs) adjustedFeas - oldFeasibilityAdjusted := oldFeasibilityAdjusted.push adjustedFeas - - for proof in solEqAtomProofs do - trace[Meta.debug] "solEqAtomProofs: {← inferType proof}" - - for feas in oldFeasibilityAdjusted do - trace[Meta.debug] "oldFeasibilityAdjusted: {← inferType feas}" - - let newFeasibility ← newConstrProofs.mapM (fun e => - withLocalDeclsDNondep vconds fun cs => do - withLocalDeclsDNondep bconds fun bs => do - mkLambdaFVars (xs ++ cs ++ bs) (e.replaceFVars vs1 vs1Sol)) - - for f in atomData.feasibility do - trace[Meta.debug] "feasibility: {← inferType f}" - - for nf in newFeasibility do - trace[Meta.debug] "newFeasibility: {← inferType nf}" - - let constraintsFromCanonConstraints := - pat.optimality.constr.map Tree.val - - for cfrc in constraintsFromCanonConstraints do - trace[Meta.debug] "constraintsFromCanonConstraints: {← inferType cfrc}" - - if canonConstrs.size != constraintsFromCanonConstraints.size then - throwError ("Expected same length: {canonConstrs} and " ++ - "{constraintsFromCanonConstraints}") - - let objFunFromCanonObjFun := pat.optimality.objFun.val - trace[Meta.debug] "objFunFromCanonObjFun: {← inferType objFunFromCanonObjFun}" - - trace[Meta.debug] "pat.optimality.objFun: {← inferType atomData.optimality}" - - let newOptimality ← - withLocalDeclsDNondep bconds fun bs => do - let optimalityXsBConds := mkAppN atomData.optimality (xs ++ bs ++ vs1) - trace[Meta.debug] "newOptimality: {← inferType optimalityXsBConds}" - withLocalDeclsDNondep (canonConstrs.map (fun rc => (`_, rc))) fun cs => do - -- First, apply all constraints. - let mut optimalityAfterCanon := optimalityXsBConds - for i in [:canonConstrs.size] do - trace[Meta.debug] "optimalityAfterCanon: {← inferType optimalityAfterCanon}" - let c := mkApp constraintsFromCanonConstraints[i]! cs[i]! - optimalityAfterCanon := mkApp optimalityAfterCanon c - -- Then, adjust the condition using `objFunFromCanonObjFun`. - trace[Meta.debug] "optimalityAfterCanon: {← inferType optimalityAfterCanon}" - let monoArgsCount := getMonoArgsCount objCurv atomData.argKinds - let optimalityAfterApplyWithConditionAdjusted ← - lambdaTelescope (← whnf optimalityAfterCanon) <| fun xs e => do - -- Every extra argument has an extra condition, e.g. x', x ≤ x. - trace[Meta.debug] "xs: {xs}" - let monoArgs := xs[:2 * monoArgsCount] - trace[Meta.debug] "monoArgs: {monoArgs}" - trace[Meta.debug] "e: {← inferType e}" - let optCondition ← mkLambdaFVars xs[2 * monoArgsCount:] e - let newCond ← - if atomData.curvature == Curvature.Convex then - mkAppOptM ``le_trans #[ - atomRange, none, none, none, none, - optCondition, objFunFromCanonObjFun] - else - -- TODO: concave. but convex_set too? - mkAppOptM ``le_trans #[ - atomRange, none, none, none, none, - objFunFromCanonObjFun, optCondition] - mkLambdaFVars monoArgs newCond - - trace[Meta.debug] "optimalityAfterApplyWithConditionAdjusted: {← inferType optimalityAfterApplyWithConditionAdjusted}" - trace[Meta.debug] "newOptimality applied: {← inferType optimalityAfterCanon}" - let ds := pat.newConstrVarsArray.map (mkFVar ·.fvarId) - mkLambdaFVars (xs ++ bs ++ vs1 ++ vs2 ++ cs ++ ds) optimalityAfterApplyWithConditionAdjusted - - trace[Meta.debug] "newOptimality: {← inferType newOptimality}" - - let mut newVCondElims := #[] - for vcondElim in atomData.vcondElim do - let newVCondElim := mkAppN vcondElim (xs ++ vs1) - let newVCondElim ← - withLocalDeclsDNondep (canonConstrs.map (fun rc => (`_, rc))) fun cs => do - let mut newVCondElim := newVCondElim - for i in [:canonConstrs.size] do - let c := mkApp constraintsFromCanonConstraints[i]! cs[i]! - newVCondElim := mkApp newVCondElim c - let ds := pat.newConstrVarsArray.map (mkFVar ·.fvarId) - mkLambdaFVars (xs ++ vs1 ++ vs2) <| ← mkLambdaFVars (cs ++ ds) newVCondElim - newVCondElims := newVCondElims.push newVCondElim - - let atomData' : GraphAtomData := { atomData with - impVars := atomData.impVars.append - (← pat.newVarDecls.toArray.mapM fun decl => do - return (decl.userName, ← mkLambdaFVars xs decl.type)) - impObjFun := ← mkLambdaFVars xs $ ← mkLambdaFVars (vs1 ++ vs2) $ - pat.canonExprs.objFun.val - impConstrs := ← (canonConstrs ++ pat.newConstrs).mapM - (fun c => do - withLocalDeclsDNondep bconds fun bs => do - return ← mkLambdaFVars xs <| ← mkLambdaFVars bs <| ← mkLambdaFVars (vs1 ++ vs2) c) - solution := atomData.solution.append - (← pat.forwardImagesNewVars.mapM (fun e => mkLambdaFVars xs - (e.replaceFVars vs1 vs1Sol))) - feasibility := oldFeasibilityAdjusted ++ newFeasibility - optimality := newOptimality - vcondElim := newVCondElims - - } - - -- TODO: solEqAtom := sorry. ????????? - - return atomData' - -/-- -/ -def elabCurvature (curv : Syntax) : TermElabM Curvature := do - match curv.getId with - | `«convex» => return Curvature.Convex - | `«concave» => return Curvature.Concave - | `«affine» => return Curvature.Affine - | `«convex_set» => return Curvature.ConvexSet - | _ => throwError "Unknown curvature : {curv.getId}" - -/-- -/ -def elabExpr (expr : Syntax) (argDecls : Array LocalDecl) (ty : Option Expr := none): TermElabM (Expr × Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let body ← Elab.Term.elabTermAndSynthesizeEnsuringType expr ty - let bodyTy ← inferType body - return (← mkLambdaFVars xs body, bodyTy) - -/-- -/ -def elabArgKindsAux : List Syntax → TermElabM (List LocalDecl × List ArgKind) -| [] => pure ([], []) -| List.cons stx stxs => do - match stx with - | `(arg_with_kind|($id : $ty) $argkind) => do - let ty ← Term.elabTerm ty.raw none - withLocalDeclD id.getId ty fun x => do - let argkind ← - match argkind.raw with - | `(arg_kind| +) => pure $ ArgKind.Increasing - | `(arg_kind| -) => pure $ ArgKind.Decreasing - | `(arg_kind| ?) => pure $ ArgKind.Neither - | `(arg_kind| &) => pure $ ArgKind.Constant - | _ => throwError "Unknown argument kind {argkind}" - let argDecl ← x.fvarId!.getDecl - let (argDecls, argKinds) ← elabArgKindsAux stxs - return (argDecl :: argDecls, argkind :: argKinds) - | _ => throwUnsupportedSyntax - -/-- -/ -def elabArgKinds (args : Array Syntax) : TermElabM (Array LocalDecl × Array ArgKind) := do - let (argDecls, argKinds) ← elabArgKindsAux args.toList - return (argDecls.toArray, argKinds.toArray) - -/-- -/ -def elabVConditions (argDecls : Array LocalDecl) (vcondStx : Array Syntax) : - TermElabM (Array (Lean.Name × Expr) × Std.HashMap Lean.Name Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let mut vcondMap : Std.HashMap Lean.Name Expr := {} - let mut vconds : Array (Lean.Name × Expr) := #[] - for stx in vcondStx do - let vcond ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] (some $ mkSort levelZero) - let vcond ← mkLambdaFVars xs vcond - vcondMap := vcondMap.insert stx[1].getId vcond - vconds := vconds.push (stx[1].getId, vcond) - return (vconds, vcondMap) - -/-- Assumtions can be elaborated exactly like vconditions. -/ -def elabBConds := elabVConditions - -/-- -/ -def elabImpVars (argDecls : Array LocalDecl) (impVarsStx : Array Syntax) : - TermElabM (Array (Lean.Name × Expr) × Std.HashMap Lean.Name Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let mut impVarMap : Std.HashMap Lean.Name Expr := {} - let mut impVars : Array (Lean.Name × Expr) := #[] - for stx in impVarsStx do - let impVarTy ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] none - let impVarTy ← mkLambdaFVars xs $ impVarTy - impVars := impVars.push (stx[1].getId, impVarTy) - impVarMap := impVarMap.insert stx[1].getId impVarTy - return (impVars, impVarMap) - -/-- -/ -def elabImpObj (argDecls : Array LocalDecl) (impVars : Array (Lean.Name × Expr)) - (impObjStx : Syntax) (bodyTy : Expr) : TermElabM Expr := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun ts => do - let impObj ← Elab.Term.elabTermAndSynthesizeEnsuringType impObjStx (some bodyTy) - return ← mkLambdaFVars xs $ ← mkLambdaFVars ts impObj - -/-- -/ -def elabImpConstrs (argDecls : Array LocalDecl) (impVars : Array (Lean.Name × Expr)) - (bconds : Array (Lean.Name × Expr)) (impConstrStx : Array Syntax) : - TermElabM (Array (Lean.Name × Expr)) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun as => do - withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do - let mut impConstrs : Array (Lean.Name × Expr) := #[] - for stx in impConstrStx do - let impConstr ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] (mkSort levelZero) - let impConstr ← mkLambdaFVars xs <| ← mkLambdaFVars as <| ← mkLambdaFVars vs impConstr - trace[Meta.debug] "impConstr: {← inferType impConstr}" - impConstrs := impConstrs.push (stx[1].getId, impConstr) - return impConstrs - -/-- -/ -def elabSols (argDecls : Array LocalDecl) (impVars : Array (Lean.Name × Expr)) - (impVarMap : Std.HashMap Lean.Name Expr) (solStx : Array Syntax) : TermElabM (Array Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let impVarNames := impVars.map (·.fst) - let mut solMap : Std.HashMap Lean.Name Expr := {} - for stx in solStx do - let id ← match impVarMap.find? stx[1].getId with - | some id => pure id - | none => throwError "Unknown variable: {stx[1].getId}" - let ty := mkAppNBeta id xs - solMap := solMap.insert stx[1].getId $ ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] (some ty) - let sols ← impVarNames.mapM - fun n => match solMap.find? n with - | some sol => pure sol - | none => throwError "solution not found {n}" - let sols ← sols.mapM (do return ← mkLambdaFVars xs ·) - return sols - -/-- -/ -def elabSolEqAtom (argDecls : Array LocalDecl) (vconds : Array (Lean.Name × Expr)) - (impObj : Expr) (sols : Array Expr) - (expr : Expr) (stx : Syntax) : TermElabM Expr := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let vconds := vconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep vconds fun cs => do - let body := mkAppNBeta expr xs - let impObj := mkAppNBeta impObj xs - let sols := sols.map (mkAppNBeta · xs) - let impObj' := convertLambdasToLets impObj sols - let ty ← mkEq impObj' body - trace[Meta.debug] "ensuring type {ty}" - let solEqAtom ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) - return ← mkLambdaFVars xs $ ← mkLambdaFVars cs solEqAtom - -/-- -/ -def elabFeas (argDecls : Array LocalDecl) (vconds bconds : Array (Lean.Name × Expr)) - (impConstrs : Array (Lean.Name × Expr)) (sols : Array Expr) (stx : Array Syntax) : TermElabM (Array Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let bconds := bconds.map fun (n, c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun as => do - let vconds := vconds.map fun (n, c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep vconds fun cs => do - let impConstrs := impConstrs.map fun (n, c) => (n, mkAppNBeta c xs) - let impConstrs := impConstrs.map fun (n, c) => (n, mkAppNBeta c as) - let sols := sols.map (mkAppNBeta · xs) - let mut feas := #[] - for i in [:impConstrs.size] do - if (stx[i]!)[1]!.getId != impConstrs[i]!.1 then - throwError "Feasibility: Expected {impConstrs[i]!.1}, found {(stx[i]!)[1]!}" - let ty := convertLambdasToLets impConstrs[i]!.2 sols - let f ← Elab.Term.elabTermAndSynthesizeEnsuringType (stx[i]!)[3]! (some ty) - let f ← mkLambdaFVars xs $ ← mkLambdaFVars as $ ← mkLambdaFVars cs f - feas := feas.push f - return feas - -/-- -/ -def withCopyOfNonConstVars (xs : Array Expr) (argKinds : Array ArgKind) (f : Array Expr → Array Expr → TermElabM Expr) : - TermElabM Expr := do - -- Determine subset of non-constant arguments. - let mut argDeclInfo : Array (Lean.Name × (Array Expr → TermElabM Expr)) := #[] - for i in [:xs.size] do - if argKinds[i]! != ArgKind.Constant then - let ty := ← inferType xs[i]! - let name := Name.mkSimple ((ToString.toString (← FVarId.getDecl xs[i]!.fvarId!).userName) ++ "'") - argDeclInfo := argDeclInfo.push (name, fun _ => pure ty) - - withLocalDeclsD argDeclInfo fun ys => do - let mut allYs := #[] - let mut j := 0 - for i in [:xs.size] do - -- use old variable if constant, otherwise use new variable - if argKinds[i]! == ArgKind.Constant then - allYs := allYs.push xs[i]! - else - allYs := allYs.push ys[j]! - j := j + 1 - return ← f ys allYs - - - -/-- -/ -def elabOpt (curv : Curvature) (argDecls : Array LocalDecl) (expr : Expr) (bconds : Array (Lean.Name × Expr)) - (impVars : Array (Lean.Name × Expr)) (impObj : Expr) (impConstrs : Array (Lean.Name × Expr)) - (argKinds : Array ArgKind) (stx : Syntax) : TermElabM Expr := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun as => do - withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c xs) - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c as) - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c vs) - withLocalDeclsDNondep impConstrs fun is => do - let impObj := mkAppNBeta (mkAppNBeta impObj xs) vs - let ty ← shiftingArgs curv xs argKinds fun monoXs ys => - let body := mkAppNBeta expr xs - let body := body.replaceFVars monoXs ys - match curv with - | Curvature.Concave => return ← mkLe impObj body - | Curvature.Convex => return ← mkLe body impObj - | _ => throwError "invalid curvature: {curv}" - trace[Meta.debug] "elabOpt ensuring {ty}" - let opt ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) - trace[Meta.debug] "elabOpt opt: {← inferType opt}" - return ← mkLambdaFVars xs $ ← mkLambdaFVars as $ ← mkLambdaFVars vs $ ← mkLambdaFVars is opt - -/-- -/ -def elabVCondElim (curv : Curvature) (argDecls : Array LocalDecl) (bconds vconds : Array (Lean.Name × Expr)) (vcondMap : Std.HashMap Lean.Name Expr) - (impVars : Array (Lean.Name × Expr)) (impConstrs : Array (Lean.Name × Expr)) (argKinds : Array ArgKind) (stx : Array Syntax) : TermElabM (Array Expr) := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c xs) - let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun as => do - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c as) - withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do - let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c vs) - withLocalDeclsDNondep impConstrs fun is => do - let mut vcondElimMap : Std.HashMap Lean.Name Expr := {} - for i in [:stx.size] do - let ty ← shiftingArgs curv xs argKinds fun monoXs ys => do - let id ← match vcondMap.find? (stx[i]!)[1]!.getId with - | some id => pure id - | none => throwError "Unknown vcondition: {(stx[i]!)[1]!.getId}" - let body := mkAppNBeta id xs - let body := body.replaceFVars monoXs ys - return body - let vcondElim ← Elab.Term.elabTermAndSynthesizeEnsuringType (stx[i]!)[3]! (some $ ty) - let vcondElim ← mkLambdaFVars xs $ ← mkLambdaFVars vs $ ← mkLambdaFVars is vcondElim - vcondElimMap := vcondElimMap.insert (stx[i]!)[1]!.getId vcondElim - return ← vconds.mapM - fun (n, _) => match vcondElimMap.find? n with - | some vcond => pure vcond - | none => throwError "vcondition not found: {n}" - -/-- -/ +/-- Standard `declare_atom` command (without background conditions). -/ @[command_elab atomCommand] unsafe def elabAtomCommand : CommandElab -| `(declare_atom $id [ $curv ] $args* : $expr := - vconditions $vconds* - implementationVars $impVars* - implementationObjective $impObj - implementationConstraints $impConstrs* - solution $sols* - solutionEqualsAtom $solEqAtom - feasibility $feas* - optimality $opt - vconditionElimination $vcondElim*) => do - let atomData ← liftTermElabM do - -- TODO: check if trating convexset as convex makes sense. - let curvTag ← elabCurvature curv.raw - let curv := if curvTag == Curvature.ConvexSet then Curvature.Concave else curvTag - let (argDecls, argKinds) ← elabArgKinds args.rawImpl - let (expr, bodyTy) ← elabExpr expr.raw argDecls - let (vconds, vcondMap) ← elabVConditions argDecls vconds.rawImpl - let (impVars, impVarMap) ← elabImpVars argDecls impVars.rawImpl - let impObj ← elabImpObj argDecls impVars impObj.raw bodyTy - let impConstrs ← elabImpConstrs argDecls impVars #[] impConstrs.rawImpl - let sols ← elabSols argDecls impVars impVarMap sols.rawImpl - let solEqAtom ← elabSolEqAtom argDecls vconds impObj sols expr solEqAtom.raw - let feas ← elabFeas argDecls vconds #[] impConstrs sols feas.rawImpl - let opt ← elabOpt curv argDecls expr #[] impVars impObj impConstrs argKinds opt.raw - let vcondElim ← elabVCondElim curv argDecls #[] vconds vcondMap impVars impConstrs argKinds vcondElim.rawImpl - - let atomData := { - id := id.getId - curvature := curvTag - expr := expr - argKinds := argKinds - bconds := #[] - vconds := vconds - impVars := impVars - impObjFun := impObj - impConstrs := impConstrs.map (·.snd) - solution := sols - solEqAtom := solEqAtom - feasibility := feas - optimality := opt - vcondElim := vcondElim - } - return atomData - - let objCurv := atomData.curvature - -- match atomData.curvature with - -- | Curvature.ConvexSet => Curvature.ConvexSet - -- | _ => Curvature.Affine - - trace[Meta.debug] "before canon sol eq atom: {atomData.solEqAtom}" - - let atomData ← canonAtomData objCurv atomData - -- trace[Meta.debug] "HERE Canon atom: {atomData}" - let atomData ← addAtomDataDecls id.getId atomData - -- trace[Meta.debug] "HERE Added atom" - - liftTermElabM do - trace[Meta.debug] "Add atom: {atomData}" - addAtom $ AtomData.graph atomData -| _ => throwUnsupportedSyntax - + | `(declare_atom $id [ $curv ] $args* : $expr := + vconditions $vconds* + implementationVars $impVars* + implementationObjective $impObj + implementationConstraints $impConstrs* + solution $sols* + solutionEqualsAtom $solEqAtom + feasibility $feas* + optimality $opt + vconditionElimination $vcondElim*) => do + let atomData ← liftTermElabM do + let curv ← elabCurvature curv.raw + let (argDecls, argKinds) ← elabArgKinds args.rawImpl + let (expr, bodyTy) ← elabExpr expr.raw argDecls + let (vconds, vcondMap) ← elabVConditions argDecls vconds.rawImpl + let (impVars, impVarMap) ← elabImpVars argDecls impVars.rawImpl + let impObj ← elabImpObj argDecls impVars impObj.raw bodyTy + let impConstrs ← elabImpConstrs argDecls impVars #[] impConstrs.rawImpl + let sols ← elabSols argDecls impVars impVarMap sols.rawImpl + let solEqAtom ← elabSolEqAtom argDecls vconds impObj sols expr solEqAtom.raw + let feas ← elabFeas argDecls vconds #[] impConstrs sols feas.rawImpl + let opt ← elabOpt curv argDecls expr #[] impVars impObj impConstrs argKinds opt.raw + let vcondElim ← elabVCondElim curv argDecls #[] vconds vcondMap impVars impConstrs argKinds + vcondElim.rawImpl + + let atomData := { + id := id.getId + curvature := curv + expr := expr + argKinds := argKinds + bconds := #[] + vconds := vconds + impVars := impVars + impObjFun := impObj + impConstrs := impConstrs.map (·.snd) + solution := sols + solEqAtom := solEqAtom + feasibility := feas + optimality := opt + vcondElim := vcondElim + } + return atomData + + let atomData ← liftTermElabM <| canonAtomData atomData.curvature atomData + let atomData ← addAtomDataDecls id.getId atomData + + liftTermElabM do + trace[CvxLean.debug] "Adding atom data {atomData}." + addAtom <| AtomData.graph atomData + | _ => throwUnsupportedSyntax -/-- -/ +/-- Standard `declare_atom` command (with background conditions) -/ @[command_elab atomWithBCondsCommand] unsafe def elabAtomWithBCondsCommand : CommandElab -| `(declare_atom $id [ $curv ] $args* : $expr := - bconditions $bconds* - vconditions $vconds* - implementationVars $impVars* - implementationObjective $impObj - implementationConstraints $impConstrs* - solution $sols* - solutionEqualsAtom $solEqAtom - feasibility $feas* - optimality $opt - vconditionElimination $vcondElim*) => do - let atomData ← liftTermElabM do - let curv ← elabCurvature curv.raw - let (argDecls, argKinds) ← elabArgKinds args.rawImpl - let (expr, bodyTy) ← elabExpr expr.raw argDecls - let (bconds, _) ← elabBConds argDecls bconds.rawImpl - let (vconds, vcondMap) ← elabVConditions argDecls vconds.rawImpl - let (impVars, impVarMap) ← elabImpVars argDecls impVars.rawImpl - let impObj ← elabImpObj argDecls impVars impObj.raw bodyTy - let impConstrs ← elabImpConstrs argDecls impVars bconds impConstrs.rawImpl - let sols ← elabSols argDecls impVars impVarMap sols.rawImpl - let solEqAtom ← elabSolEqAtom argDecls vconds impObj sols expr solEqAtom.raw - let feas ← elabFeas argDecls vconds bconds impConstrs sols feas.rawImpl - let opt ← elabOpt curv argDecls expr bconds impVars impObj impConstrs argKinds opt.raw - let vcondElim ← elabVCondElim curv argDecls bconds vconds vcondMap impVars impConstrs argKinds vcondElim.rawImpl - - let atomData := { - id := id.getId - curvature := curv - expr := expr - argKinds := argKinds - bconds := bconds - vconds := vconds - impVars := impVars - impObjFun := impObj - impConstrs := impConstrs.map (·.snd) - solution := sols - solEqAtom := solEqAtom - feasibility := feas - optimality := opt - vcondElim := vcondElim - } - return atomData - - let atomData ← canonAtomData atomData.curvature atomData--Curvature.Affine atomData - let atomData ← addAtomDataDecls id.getId atomData - - liftTermElabM do - trace[Meta.debug] "Add atom: {atomData}" - addAtom $ AtomData.graph atomData -| _ => throwUnsupportedSyntax - -/-- -/ -def mapNonConstant (xs : Array Expr) (argKinds : Array ArgKind) (f : Expr → TermElabM Expr) : - TermElabM (Array Expr) := - (Array.zip xs argKinds).mapM fun (x, kind) => do - if kind == ArgKind.Constant - then return x - else return ← f x - -/-- -/ -def elabHom (argDecls : Array LocalDecl) (expr : Expr) (argKinds : Array ArgKind) (stx : Syntax) : TermElabM Expr := do - withExistingLocalDecls argDecls.toList do - withLocalDeclD `κ (mkConst ``Real) fun κ => do - let xs := argDecls.map (mkFVar ·.fvarId) - let zero := mkAppNBeta expr $ ← mapNonConstant xs argKinds - fun x => do return ← mkNumeral (← inferType x) 0 - let lhs ← mkAdd - (← mkAppM ``HSMul.hSMul #[κ, mkAppNBeta expr xs]) - zero - let rhs ← mkAdd - (mkAppNBeta expr $ ← mapNonConstant xs argKinds - fun x => mkAppM ``HSMul.hSMul #[κ, x]) - (← mkAppM ``HSMul.hSMul #[κ, zero]) - let ty ← mkEq lhs rhs - let hom ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) - return ← mkLambdaFVars xs $ ← mkLambdaFVars #[κ] hom - -def elabAdd (argDecls : Array LocalDecl) (expr : Expr) (argKinds : Array ArgKind) (stx : Syntax) : TermElabM Expr := do - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - withCopyOfNonConstVars xs argKinds fun newYs ys => do - let zero := mkAppNBeta expr $ ← mapNonConstant xs argKinds - fun x => do return ← mkNumeral (← inferType x) 0 - let lhs ← mkAdd (mkAppNBeta expr xs) (mkAppNBeta expr ys) - let rhs ← mkAdd - (mkAppNBeta expr $ - ← (Array.zip argKinds (Array.zip xs ys)).mapM fun (kind, x, y) => do - if kind == ArgKind.Constant - then return x - else mkAdd x y) - (zero) - let ty ← mkEq lhs rhs - let add ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) - return ← mkLambdaFVars xs $ ← mkLambdaFVars newYs add + | `(declare_atom $id [ $curv ] $args* : $expr := + bconditions $bconds* + vconditions $vconds* + implementationVars $impVars* + implementationObjective $impObj + implementationConstraints $impConstrs* + solution $sols* + solutionEqualsAtom $solEqAtom + feasibility $feas* + optimality $opt + vconditionElimination $vcondElim*) => do + let atomData ← liftTermElabM do + let curv ← elabCurvature curv.raw + let (argDecls, argKinds) ← elabArgKinds args.rawImpl + let (expr, bodyTy) ← elabExpr expr.raw argDecls + let (bconds, _) ← elabBConds argDecls bconds.rawImpl + let (vconds, vcondMap) ← elabVConditions argDecls vconds.rawImpl + let (impVars, impVarMap) ← elabImpVars argDecls impVars.rawImpl + let impObj ← elabImpObj argDecls impVars impObj.raw bodyTy + let impConstrs ← elabImpConstrs argDecls impVars bconds impConstrs.rawImpl + let sols ← elabSols argDecls impVars impVarMap sols.rawImpl + let solEqAtom ← elabSolEqAtom argDecls vconds impObj sols expr solEqAtom.raw + let feas ← elabFeas argDecls vconds bconds impConstrs sols feas.rawImpl + let opt ← elabOpt curv argDecls expr bconds impVars impObj impConstrs argKinds opt.raw + let vcondElim ← elabVCondElim curv argDecls bconds vconds vcondMap impVars impConstrs argKinds + vcondElim.rawImpl + + let atomData := { + id := id.getId + curvature := curv + expr := expr + argKinds := argKinds + bconds := bconds + vconds := vconds + impVars := impVars + impObjFun := impObj + impConstrs := impConstrs.map (·.snd) + solution := sols + solEqAtom := solEqAtom + feasibility := feas + optimality := opt + vcondElim := vcondElim + } + return atomData + + let atomData ← liftTermElabM <| canonAtomData atomData.curvature atomData + let atomData ← addAtomDataDecls id.getId atomData + + liftTermElabM do + trace[CvxLean.debug] "Adding atom data {atomData}" + addAtom <| AtomData.graph atomData + | _ => throwUnsupportedSyntax -/-- -/ +/-- Version of `declare_atom` for affine atoms. While the optimality proof is set up as if the atom +were concave, since it is affine, we can flip all the inequalities. We store both proofs. +TODO: Can this be simplified? -/ @[command_elab affineAtomCommand] unsafe def elabAffineAtomCommand : CommandElab -| `(declare_atom $id [ affine ] $args* : $expr := - bconditions $bconds* - homogenity $hom - additivity $add - optimality $opt) => do - let atomData ← liftTermElabM do - let (argDecls, argKinds) ← elabArgKinds args.rawImpl - let (expr, bodyTy) ← elabExpr expr.raw argDecls - let vconds := #[] - let impVars := #[] - let impObj := expr - let impConstrs := #[] - let sols := #[] - let solEqAtom ← lambdaTelescope expr fun xs body => do return ← mkLambdaFVars xs $ ← mkEqRefl body - let feas := #[] - let (bconds, _) ← elabBConds argDecls bconds.rawImpl - let hom ← elabHom argDecls expr argKinds hom.raw - check hom -- Property is not saved. This is just a sanity check. - let add ← elabAdd argDecls expr argKinds add.raw - check add -- Property is not saved. This is just a sanity check. - let optConcave ← elabOpt Curvature.Concave argDecls expr bconds impVars impObj impConstrs argKinds opt.raw - let optConvex ← - withExistingLocalDecls argDecls.toList do - let xs := argDecls.map (mkFVar ·.fvarId) - let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) - withLocalDeclsDNondep bconds fun as => do - withCopyOfMonoXs xs argKinds fun monoXs ys monoArgKind => do - let mut optInverted := optConcave - let mut i' := 0 - for i in [:xs.size] do - if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then - optInverted := mkApp optInverted ys[i']! - i' := i' + 1 - else - optInverted := mkApp optInverted xs[i]! - optInverted := mkAppN optInverted as - for i in [:xs.size] do - if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then - optInverted := mkApp optInverted xs[i]! - return ← mkLambdaFVars xs $ ← mkLambdaFVars as $ ← mkLambdaFVars ys optInverted - let opt ← mkAppM ``And.intro #[optConcave, optConvex] - let vcondElim := #[] - - let atomData := { - id := id.getId - curvature := Curvature.Affine - expr := expr - argKinds := argKinds - vconds := vconds - bconds := bconds - impVars := impVars - impObjFun := impObj - impConstrs := impConstrs.map (·.snd) - solution := sols - solEqAtom := solEqAtom - feasibility := feas - optimality := opt - vcondElim := vcondElim - } - return atomData - - let atomData ← addAtomDataDecls id.getId atomData - - liftTermElabM do - trace[Meta.debug] "Add atom: {atomData}" - addAtom $ AtomData.graph atomData -| _ => throwUnsupportedSyntax - -/-- -/ -@[command_elab coneAtomCommand] unsafe def elabConeAtomCommand : CommandElab -| `(declare_atom $id [ cone ] $args* : $expr := + | `(declare_atom $id [ affine ] $args* : $expr := + bconditions $bconds* + homogenity $hom + additivity $add optimality $opt) => do - let atomData ← liftTermElabM do - let (argDecls, argKinds) ← elabArgKinds args.rawImpl - let (expr, bodyTy) ← elabExpr expr.raw argDecls (ty := some (mkSort levelZero)) - let vconds := #[] - let impVars := #[] - let impObj := expr - let impConstrs := #[] - let sols := #[] - let solEqAtom ← lambdaTelescope expr fun xs body => - do return ← mkLambdaFVars xs $ ← mkEqRefl body - let feas := #[] - let bconds := #[] - -- TODO: Not concave.... - let opt ← elabOpt Curvature.Concave argDecls expr bconds impVars impObj impConstrs argKinds opt.raw - let vcondElim := #[] - - let atomData := { - id := id.getId - curvature := Curvature.ConvexSet --Curvature.Concave - expr := expr - argKinds := argKinds - vconds := vconds - bconds := bconds - impVars := impVars - impObjFun := impObj - impConstrs := impConstrs.map (·.snd) - solution := sols - solEqAtom := solEqAtom - feasibility := feas - optimality := opt - vcondElim := vcondElim - } - return atomData - - let atomData ← addAtomDataDecls id.getId atomData + let atomData ← liftTermElabM do + let (argDecls, argKinds) ← elabArgKinds args.rawImpl + let (expr, _bodyTy) ← elabExpr expr.raw argDecls + let vconds := #[] + let impVars := #[] + let impObj := expr + let impConstrs := #[] + let sols := #[] + let solEqAtom ← lambdaTelescope expr fun xs body => do + return ← mkLambdaFVars xs <| ← mkEqRefl body + let feas := #[] + let (bconds, _) ← elabBConds argDecls bconds.rawImpl + let hom ← elabHom argDecls expr argKinds hom.raw + check hom -- Property is not saved. This is just a sanity check. + let add ← elabAdd argDecls expr argKinds add.raw + check add -- Property is not saved. This is just a sanity check. + let optConcave ← elabOpt Curvature.Concave argDecls expr bconds impVars impObj impConstrs + argKinds opt.raw + let optConvex ← + withExistingLocalDecls argDecls.toList do + let xs := argDecls.map (mkFVar ·.fvarId) + let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) + withLocalDeclsDNondep bconds fun as => do + withCopyOfMonoXs xs argKinds fun _monoXs ys _monoArgKind => do + let mut optInverted := optConcave + let mut i' := 0 + for i in [:xs.size] do + if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then + optInverted := mkApp optInverted ys[i']! + i' := i' + 1 + else + optInverted := mkApp optInverted xs[i]! + optInverted := mkAppN optInverted as + for i in [:xs.size] do + if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then + optInverted := mkApp optInverted xs[i]! + return ← mkLambdaFVars (xs ++ as ++ ys) optInverted + let opt ← mkAppM ``And.intro #[optConcave, optConvex] + let vcondElim := #[] + + let atomData := { + id := id.getId + curvature := Curvature.Affine + expr := expr + argKinds := argKinds + vconds := vconds + bconds := bconds + impVars := impVars + impObjFun := impObj + impConstrs := impConstrs.map (·.snd) + solution := sols + solEqAtom := solEqAtom + feasibility := feas + optimality := opt + vcondElim := vcondElim + } + return atomData + + let atomData ← addAtomDataDecls id.getId atomData + + liftTermElabM do + trace[CvxLean.debug] "Addding atom {atomData}" + addAtom <| AtomData.graph atomData + | _ => throwUnsupportedSyntax - liftTermElabM do - trace[Meta.debug] "Add atom: {atomData}" - addAtom $ AtomData.graph atomData -| _ => throwUnsupportedSyntax +/-- Version of `declare_command` for cones. +TODO: We probably don't even need `optimality` here. Right now, all cone monotonicites are set to +"neither" to ensure that the terms in conic constraints are affine. -/ +@[command_elab coneAtomCommand] unsafe def elabConeAtomCommand : CommandElab + | `(declare_atom $id [ cone ] $args* : $expr := + optimality $opt) => do + let atomData ← liftTermElabM do + let (argDecls, argKinds) ← elabArgKinds args.rawImpl + let (expr, _bodyTy) ← elabExpr expr.raw argDecls (ty := some (mkSort levelZero)) + let vconds := #[] + let impVars := #[] + let impObj := expr + let impConstrs := #[] + let sols := #[] + let solEqAtom ← lambdaTelescope expr fun xs body => + do return ← mkLambdaFVars xs <| ← mkEqRefl body + let feas := #[] + let bconds := #[] + let opt ← elabOpt Curvature.ConvexSet argDecls expr bconds impVars impObj impConstrs argKinds + opt.raw + let vcondElim := #[] + + let atomData := { + id := id.getId + curvature := Curvature.ConvexSet + expr := expr + argKinds := argKinds + vconds := vconds + bconds := bconds + impVars := impVars + impObjFun := impObj + impConstrs := impConstrs.map (·.snd) + solution := sols + solEqAtom := solEqAtom + feasibility := feas + optimality := opt + vcondElim := vcondElim + } + return atomData + + let atomData ← addAtomDataDecls id.getId atomData + + liftTermElabM do + trace[CvxLean.debug] "Addding atom {atomData}" + addAtom <| AtomData.graph atomData + | _ => throwUnsupportedSyntax end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomCmdElab.lean b/CvxLean/Tactic/DCP/AtomCmdElab.lean new file mode 100644 index 00000000..964a85b1 --- /dev/null +++ b/CvxLean/Tactic/DCP/AtomCmdElab.lean @@ -0,0 +1,462 @@ +import CvxLean.Lib.Math.Data.Real +import CvxLean.Meta.Util.Error +import CvxLean.Meta.Util.Expr +import CvxLean.Meta.Util.Meta +import CvxLean.Tactic.DCP.AtomSyntax +import CvxLean.Tactic.DCP.AtomExt + +/-! +# Helpers and elaborators for the `declare_atom` command + +This file contains helper functions and elaborators for the `declare_atom` command used in +`CvxLean/Tactic/DCP/AtomCmd.lean`. We define elaborators for every field here that set up the +proof obligations for an atom depending on its signature and implementation. + +### TODO + +* Make `vconds` available in `impObj`, in the same way that `bconds` are in `elabImpConstrs`? This + needs to be synchronised with the DCP procedure. +-/ + +namespace CvxLean + +open Lean Expr Meta Elab Term Command + +section Helpers + +/-- Helper function to add auxiliary atom declarations to the environment. In particular, an atom +declaration involves four proof obligations (which might be split into several sub-proofs). This +function names them and adds them to the environment so that the identifiers can be used to build +proofs later instead of passing around the proof terms. +TODO: This does not respect namespaces. -/ +def addAtomDataDecls (id : Name) (atomData : GraphAtomData) : CommandElabM GraphAtomData := do + if atomData.solEqAtom.hasMVar then + throwAtomDeclarationError "could not add auxiliary {atomData.solEqAtom} has a metavariable." + let solEqAtom ← addAtomDataDecl (id.mkStr "solEqAtom") atomData.solEqAtom + let optimality ← addAtomDataDecl (id.mkStr "optimality") atomData.optimality + let mut feasibility := #[] + for i in [:atomData.feasibility.size] do + feasibility := feasibility.push <| + ← addAtomDataDecl (id.mkStr (s!"feasibility{i}")) atomData.feasibility[i]! + let mut vcondElim := #[] + for i in [:atomData.vcondElim.size] do + vcondElim := vcondElim.push <| + ← addAtomDataDecl (id.mkStr (s!"vcondElim{i}")) atomData.vcondElim[i]! + return { atomData with + solEqAtom := solEqAtom + feasibility := feasibility + vcondElim := vcondElim + optimality := optimality } +where + addAtomDataDecl (name : Name) (expr : Expr) : CommandElabM Expr := do + liftCoreM <| addDecl <| .defnDecl { + name := name + levelParams := [] + type := ← liftTermElabM (do return ← inferType expr) + value := expr + hints := .regular (getMaxHeight (← getEnv) expr + 1) + safety := .safe + } + return mkConst name + +/-- Create copy of arguments with non-trivial monotonicities and add them as local declarations. +This is used in `shiftingArgs`, which, in turn, is used to set up the `optimality` and +`vconditionElimination` proof obligations. -/ +def withCopyOfMonoXs {α} [Inhabited α] (xs : Array Expr) (argKinds : Array ArgKind) + (f : Array Expr → Array Expr → Array ArgKind → TermElabM α) : TermElabM α := do + -- Determine subset of monotone arguments and their behavior. + let mut argDeclInfo : Array (Name × (Array Expr → TermElabM Expr)) := #[] + let mut monoXs := #[] + let mut monoArgKind := #[] + for i in [:xs.size] do + if argKinds[i]! != ArgKind.Constant ∧ argKinds[i]! != ArgKind.Neither then + let ty := ← inferType xs[i]! + let name := (← FVarId.getDecl xs[i]!.fvarId!).userName + argDeclInfo := argDeclInfo.push (name, fun _ => pure ty) + monoXs := monoXs.push xs[i]! + monoArgKind := monoArgKind.push argKinds[i]! + + withLocalDeclsD argDeclInfo fun ys => do + f monoXs ys monoArgKind + +/-- The `optimality` and `vconditionElimination` proof obligations statements depend on the +monotonicity of the arguments of the atom. For example, suppose we have a convex atom defined over +a doman `D` with one increasing argument and with one implementation variable `v : E`. Let +`impConstrs` be the implementation constraints of the atom, `impObj` be the implementation objective +and `expr` be the atom's expression. Then, the `optimality` proof obligation will be +`∀ x y : D, ∀ v : E, x ≥ y → impConstrs(x, v) → impObj(x, v) ≥ expr(y)`. This function allows us to +build the "`∀ x y : D, x ≥ y`" part. -/ +def shiftingArgs (curv : Curvature) (xs : Array Expr) (argKinds : Array ArgKind) + (mkConcl : Array Expr → Array Expr → TermElabM Expr) : TermElabM Expr := + withCopyOfMonoXs xs argKinds fun monoXs ys monoArgKind => do + let mut ty := ← mkConcl monoXs ys + for i' in [:monoXs.size] do + let i := monoXs.size - 1 - i' + ty ← match curvatureInArg curv monoArgKind[i]! with + | Curvature.Concave => mkArrow (← mkLe monoXs[i]! ys[i]!) ty + | Curvature.Convex => mkArrow (← mkLe ys[i]! monoXs[i]!) ty + | Curvature.ConvexSet => mkArrow (← mkLe monoXs[i]! ys[i]!) ty + | _ => throwAtomDeclarationError "invalid curvature" + return ← mkForallFVars ys ty + +/-- Similar to `withCopyOfMonoXs` but only filtering out constant arguments. This is used in affine +atom declarations. -/ +def withCopyOfNonConstVars (xs : Array Expr) (argKinds : Array ArgKind) + (f : Array Expr → Array Expr → TermElabM Expr) : TermElabM Expr := do + -- Determine subset of non-constant arguments. + let mut argDeclInfo : Array (Name × (Array Expr → TermElabM Expr)) := #[] + for i in [:xs.size] do + if argKinds[i]! != ArgKind.Constant then + let ty := ← inferType xs[i]! + let name := Name.mkSimple ((toString (← FVarId.getDecl xs[i]!.fvarId!).userName) ++ "'") + argDeclInfo := argDeclInfo.push (name, fun _ => pure ty) + + withLocalDeclsD argDeclInfo fun ys => do + let mut allYs := #[] + let mut j := 0 + for i in [:xs.size] do + -- Use old variable if constant, otherwise use new variable. + if argKinds[i]! == ArgKind.Constant then + allYs := allYs.push xs[i]! + else + allYs := allYs.push ys[j]! + j := j + 1 + return ← f ys allYs + +/-- Apply a function to non-constant arguments. Used by affine atom elaborators. -/ +def mapNonConstant (xs : Array Expr) (argKinds : Array ArgKind) (f : Expr → TermElabM Expr) : + TermElabM (Array Expr) := + (Array.zip xs argKinds).mapM fun (x, kind) => do + if kind == ArgKind.Constant then return x else return ← f x + +end Helpers + +section Elaborators + +section Signature + +/-- Elaborate curvature label in `declare_atom` into a term of type `Curvature`. -/ +def elabCurvature (curv : Syntax) : TermElabM Curvature := do + match curv.getId with + | `«convex» => return Curvature.Convex + | `«concave» => return Curvature.Concave + | `«affine» => return Curvature.Affine + | `«convex_set» => return Curvature.ConvexSet + | _ => throwAtomDeclarationError "unknown curvature {curv.getId}." + +def elabArgKindsAux : List Syntax → TermElabM (List LocalDecl × List ArgKind) + | [] => pure ([], []) + | List.cons stx stxs => do + match stx with + | `(arg_with_kind| ($id : $ty) $argkind) => do + let ty ← Term.elabTerm ty.raw none + withLocalDeclD id.getId ty fun x => do + let argkind ← + match argkind.raw with + | `(arg_kind| +) => pure ArgKind.Increasing + | `(arg_kind| -) => pure ArgKind.Decreasing + | `(arg_kind| ?) => pure ArgKind.Neither + | `(arg_kind| &) => pure ArgKind.Constant + | _ => throwAtomDeclarationError "unknown argument kind {argkind}." + let argDecl ← x.fvarId!.getDecl + let (argDecls, argKinds) ← elabArgKindsAux stxs + return (argDecl :: argDecls, argkind :: argKinds) + | _ => throwUnsupportedSyntax + +/-- Given the arguments of an atom declarations (e.g., [`(x : ℝ)+`, `(y : ℝ)-`]), elaborate them +into a list of local declarations and a list of `ArgKind` labels. -/ +def elabArgKinds (args : Array Syntax) : TermElabM (Array LocalDecl × Array ArgKind) := do + let (argDecls, argKinds) ← elabArgKindsAux args.toList + return (argDecls.toArray, argKinds.toArray) + +/-- Elaborate the atom's expression. This depends on the arguments, which we assume have been +elaborated to local declarations. -/ +def elabExpr (expr : Syntax) (argDecls : Array LocalDecl) (ty : Option Expr := none) : + TermElabM (Expr × Expr) := do + withExistingLocalDecls argDecls.toList do + let xs := argDecls.map (mkFVar ·.fvarId) + let body ← elabTermAndSynthesizeEnsuringType expr ty + let bodyTy ← inferType body + return (← mkLambdaFVars xs body, bodyTy) + +/-- Conditions are just expressions (of type `Prop`) defined using the arguments. It is useful to +have them also as a hashmap for variable condition elimination. +TODO: Some duplication here, we essentially return the same data twice. There is a similar issue +with other functions below. -/ +def elabVConditions (argDecls : Array LocalDecl) (vcondStx : Array Syntax) : + TermElabM (Array (Name × Expr) × Std.HashMap Name Expr) := do + withExistingLocalDecls argDecls.toList do + let xs := argDecls.map (mkFVar ·.fvarId) + let mut vcondMap : Std.HashMap Name Expr := {} + let mut vconds : Array (Name × Expr) := #[] + for stx in vcondStx do + let vcond ← elabTermAndSynthesizeEnsuringType stx[3] (some <| mkSort levelZero) + let vcond ← mkLambdaFVars xs vcond + vcondMap := vcondMap.insert stx[1].getId vcond + vconds := vconds.push (stx[1].getId, vcond) + return (vconds, vcondMap) + +/-- Background conditions can be elaborated exactly like variable conditions. -/ +def elabBConds := elabVConditions + +end Signature + +section GraphImplementation + +/-- The implementation variables used in the graph implementation of ana tom are given as +identifiers with a type such as `(t : ℝ)`. Here, we get the names and types of all implementation +variables. Their types might depend on the arguments, mainly with vectors or arrays where the type +is, say, `Fin n → ℝ`, and `n` is an atom argument. -/ +def elabImpVars (argDecls : Array LocalDecl) (impVarsStx : Array Syntax) : + TermElabM (Array (Name × Expr) × Std.HashMap Name Expr) := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let mut impVarMap : Std.HashMap Name Expr := {} + let mut impVars : Array (Name × Expr) := #[] + for stx in impVarsStx do + let impVarTy ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] none + let impVarTy ← mkLambdaFVars xs <| impVarTy + impVars := impVars.push (stx[1].getId, impVarTy) + impVarMap := impVarMap.insert stx[1].getId impVarTy + return (impVars, impVarMap) + +/-- The implementation objective is an expression that might depend on the atom arguments and the +implementation variables. +TODO: Add background conditions. -/ +def elabImpObj (argDecls : Array LocalDecl) (impVars : Array (Name × Expr)) + (impObjStx : Syntax) (bodyTy : Expr) : TermElabM Expr := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + -- `vs` are the implementation variables. + withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do + let impObj ← Elab.Term.elabTermAndSynthesizeEnsuringType impObjStx (some bodyTy) + return ← mkLambdaFVars (xs ++ vs) impObj + +/-- The implementation constraints are expressions (of type `Prop`) that might depend on the atom +arguments and the implementation variables. Theyc an also use background conditions. -/ +def elabImpConstrs (argDecls : Array LocalDecl) (impVars : Array (Name × Expr)) + (bconds : Array (Name × Expr)) (impConstrStx : Array Syntax) : + TermElabM (Array (Name × Expr)) := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) + -- `as` are the background conditions. + withLocalDeclsDNondep bconds fun as => do + -- `vs` are the implementation variables. + withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do + let mut impConstrs : Array (Name × Expr) := #[] + for stx in impConstrStx do + let impConstr ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] (mkSort levelZero) + let impConstr ← mkLambdaFVars (xs ++ as ++ vs) impConstr + trace[CvxLean.debug] "`elabImpConstrs` constraint type {← inferType impConstr}." + impConstrs := impConstrs.push (stx[1].getId, impConstr) + return impConstrs + +end GraphImplementation + +section ProofObligations + +/-- For every implementation variable, they user must provide a solution as an expression using +the atom's arguments. These expressions are crucial to build the forward map in the equivalence +relation. -/ +def elabSols (argDecls : Array LocalDecl) (impVars : Array (Name × Expr)) + (impVarMap : Std.HashMap Name Expr) (solStx : Array Syntax) : TermElabM (Array Expr) := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let impVarNames := impVars.map (·.fst) + let mut solMap : Std.HashMap Name Expr := {} + for stx in solStx do + let id ← match impVarMap.find? stx[1].getId with + | some id => pure id + | none => throwAtomDeclarationError "unknown variable in solution {stx[1].getId}." + let ty := mkAppNBeta id xs + solMap := solMap.insert stx[1].getId <| + ← Elab.Term.elabTermAndSynthesizeEnsuringType stx[3] (some ty) + let sols ← impVarNames.mapM + fun n => match solMap.find? n with + | some sol => pure sol + | none => throwAtomDeclarationError "solution not found for {n}." + let sols ← sols.mapM (do return ← mkLambdaFVars xs ·) + return sols + +/-- Given the implementation objecitve and the user-defined solutions, the first proof obligation +is to show that instantiating the implementation variables in the implementaiton objective with the +solution expressions needs to be **equal** to the atom's expression. This is used later on to prove +the needed properties of the forward map. -/ +def elabSolEqAtom (argDecls : Array LocalDecl) (vconds : Array (Name × Expr)) + (impObj : Expr) (sols : Array Expr) + (expr : Expr) (stx : Syntax) : TermElabM Expr := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let vconds := vconds.map fun (n,c) => (n, mkAppNBeta c xs) + withLocalDeclsDNondep vconds fun cs => do + let body := mkAppNBeta expr xs + let impObj := mkAppNBeta impObj xs + let sols := sols.map (mkAppNBeta · xs) + let impObj' := convertLambdasToLets impObj sols + let ty ← mkEq impObj' body + trace[CvxLean.debug] "`elabSolEqAtom` ensuring type {ty}." + let solEqAtom ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) + mkLambdaFVars (xs ++ cs) solEqAtom + +/-- Here, we set up the second proof obligation, or rather list of proof obligations. For every +implmentation constraints, we must show that the solutions satisfy them. This function sets up +the proof goal for each constraint. This is used later on to prove the needed properties of the +forward map. -/ +def elabFeas (argDecls : Array LocalDecl) (vconds bconds : Array (Name × Expr)) + (impConstrs : Array (Name × Expr)) (sols : Array Expr) (stx : Array Syntax) : + TermElabM (Array Expr) := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let bconds := bconds.map fun (n, c) => (n, mkAppNBeta c xs) + -- `as` are the background conditions. + withLocalDeclsDNondep bconds fun as => do + let vconds := vconds.map fun (n, c) => (n, mkAppNBeta c xs) + -- `cs` are the variable conditions. + withLocalDeclsDNondep vconds fun cs => do + let impConstrs := impConstrs.map fun (n, c) => (n, mkAppNBeta c xs) + let impConstrs := impConstrs.map fun (n, c) => (n, mkAppNBeta c as) + let sols := sols.map (mkAppNBeta · xs) + let mut feas := #[] + for i in [:impConstrs.size] do + if (stx[i]!)[1]!.getId != impConstrs[i]!.1 then + throwAtomDeclarationError + "expected feasibility {impConstrs[i]!.1}, found {(stx[i]!)[1]!}." + let ty := convertLambdasToLets impConstrs[i]!.2 sols + let f ← Elab.Term.elabTermAndSynthesizeEnsuringType (stx[i]!)[3]! (some ty) + let f ← mkLambdaFVars (xs ++ as ++ cs) f + feas := feas.push f + return feas + +/-- This proof obligation has to do with how the implementation objective bounds the atom's +expression in the appropriate monotone context. It depends on the curvature of the atom. It has the +following shape. Let `x₁, ..., xₙ` and `y₁, ..., yₙ` be variables in the atom's domain (we only need +to consider the components of the domain where the atom's function has non-trivial monotonicity +(increasing or decreasing)). Let `v₁, ..., vₘ` be the implementation variables. We need to show: +``` +h₁ : y₁ ~ x₁ +... +hₙ : yₙ ~ xₙ +h : impConstrs(x₁, ..., xₙ, v₁, ..., vₘ) +⊢ impObj(x₁, ..., xₙ, v₁, ..., vₘ) ~ expr(y₁, ..., yₙ) +``` +Each `~` is replaced by the appropriate relation depending on the atom's signature. +This is used later on to prove the required properties of the backward map. -/ +def elabOpt (curv : Curvature) (argDecls : Array LocalDecl) (expr : Expr) + (bconds : Array (Name × Expr)) (impVars : Array (Name × Expr)) (impObj : Expr) + (impConstrs : Array (Name × Expr)) (argKinds : Array ArgKind) (stx : Syntax) : + TermElabM Expr := do + withExistingLocalDecls argDecls.toList do + -- `xs` are the arguments of the atom. + let xs := argDecls.map (mkFVar ·.fvarId) + let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) + -- `as` are the background conditions. + withLocalDeclsDNondep bconds fun as => do + -- `vs` are the implementation variables. + withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c xs) + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c as) + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c vs) + -- `is` are the implementation constraints. + withLocalDeclsDNondep impConstrs fun is => do + let impObj := mkAppNBeta (mkAppNBeta impObj xs) vs + let ty ← shiftingArgs curv xs argKinds fun monoXs ys => + let body := mkAppNBeta expr xs + let body := body.replaceFVars monoXs ys + match curv with + | Curvature.Concave => return ← mkLe impObj body + | Curvature.Convex => return ← mkLe body impObj + | Curvature.ConvexSet => return ← mkLe impObj body + | _ => throwAtomDeclarationError + "invalid curvature {curv} found when setting up optimality proof obligation." + trace[CvxLean.debug] "`elabOpt` ensuring {ty}." + let opt ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) + trace[CvxLean.debug] "`elabOpt` proof term type {← inferType opt}." + mkLambdaFVars (xs ++ as ++ vs ++ is) opt + +/-- For every variable condition, we must show that it can be proved from the implementation +constraints. This function sets up such goals. This is used later on to prove the required +properties of the backward map. -/ +def elabVCondElim (curv : Curvature) (argDecls : Array LocalDecl) + (bconds vconds : Array (Name × Expr)) (vcondMap : Std.HashMap Name Expr) + (impVars : Array (Name × Expr)) (impConstrs : Array (Name × Expr)) (argKinds : Array ArgKind) + (stx : Array Syntax) : TermElabM (Array Expr) := do + withExistingLocalDecls argDecls.toList do + let xs := argDecls.map (mkFVar ·.fvarId) + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c xs) + let bconds := bconds.map fun (n,c) => (n, mkAppNBeta c xs) + -- `as` are the background conditions. + withLocalDeclsDNondep bconds fun as => do + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c as) + -- `vs` are the implementation variables. + withLocalDeclsDNondep (impVars.map fun iv => (iv.1, mkAppNBeta iv.2 xs)) fun vs => do + let impConstrs := impConstrs.map fun (n,c) => (n, mkAppNBeta c vs) + -- `is` are the implementation constraints. + withLocalDeclsDNondep impConstrs fun is => do + let mut vcondElimMap : Std.HashMap Name Expr := {} + for i in [:stx.size] do + let ty ← shiftingArgs curv xs argKinds fun monoXs ys => do + let id ← match vcondMap.find? (stx[i]!)[1]!.getId with + | some id => pure id + | none => + throwAtomDeclarationError "unknown variable condition {(stx[i]!)[1]!.getId}." + let body := mkAppNBeta id xs + let body := body.replaceFVars monoXs ys + return body + let vcondElim ← Elab.Term.elabTermAndSynthesizeEnsuringType (stx[i]!)[3]! (some <| ty) + let vcondElim ← mkLambdaFVars (xs ++ vs ++ is) vcondElim + vcondElimMap := vcondElimMap.insert (stx[i]!)[1]!.getId vcondElim + return ← vconds.mapM + fun (n, _) => match vcondElimMap.find? n with + | some vcond => pure vcond + | none => throwAtomDeclarationError "variable condition {n} not found." + +end ProofObligations + +section AffineAtoms + +/-- Set up proof of homogeneity. Sanity check for affine atoms. -/ +def elabHom (argDecls : Array LocalDecl) (expr : Expr) (argKinds : Array ArgKind) (stx : Syntax) : + TermElabM Expr := do + withExistingLocalDecls argDecls.toList do + withLocalDeclD `κ (mkConst ``Real) fun κ => do + let xs := argDecls.map (mkFVar ·.fvarId) + let zero := mkAppNBeta expr <| ← mapNonConstant xs argKinds + fun x => do return ← mkNumeral (← inferType x) 0 + let lhs ← mkAdd (← mkAppM ``HSMul.hSMul #[κ, mkAppNBeta expr xs]) zero + let rhs ← mkAdd + (mkAppNBeta expr <| ← mapNonConstant xs argKinds fun x => mkAppM ``HSMul.hSMul #[κ, x]) + (← mkAppM ``HSMul.hSMul #[κ, zero]) + let ty ← mkEq lhs rhs + let hom ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) + return ← mkLambdaFVars xs <| ← mkLambdaFVars #[κ] hom + +/-- Set up proof of additivity. Sanity check for affine atoms. -/ +def elabAdd (argDecls : Array LocalDecl) (expr : Expr) (argKinds : Array ArgKind) (stx : Syntax) : + TermElabM Expr := do + withExistingLocalDecls argDecls.toList do + let xs := argDecls.map (mkFVar ·.fvarId) + withCopyOfNonConstVars xs argKinds fun newYs ys => do + let zero := mkAppNBeta expr <| ← mapNonConstant xs argKinds + fun x => do return ← mkNumeral (← inferType x) 0 + let lhs ← mkAdd (mkAppNBeta expr xs) (mkAppNBeta expr ys) + let rhs ← mkAdd + (mkAppNBeta expr <| + ← (Array.zip argKinds (Array.zip xs ys)).mapM fun (kind, x, y) => do + if kind == ArgKind.Constant then return x else mkAdd x y) + zero + let ty ← mkEq lhs rhs + let add ← Elab.Term.elabTermAndSynthesizeEnsuringType stx (some ty) + return ← mkLambdaFVars xs <| ← mkLambdaFVars newYs add + +end AffineAtoms + +end Elaborators + +end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomCmdMultiLevel.lean b/CvxLean/Tactic/DCP/AtomCmdMultiLevel.lean new file mode 100644 index 00000000..61ff83bc --- /dev/null +++ b/CvxLean/Tactic/DCP/AtomCmdMultiLevel.lean @@ -0,0 +1,270 @@ +import CvxLean.Meta.Util.Expr +import CvxLean.Meta.Util.Meta +import CvxLean.Meta.Util.Error +import CvxLean.Tactic.DCP.AtomCurvRule +import CvxLean.Tactic.DCP.AtomExt +import CvxLean.Tactic.DCP.AtomCmdElab +import CvxLean.Tactic.DCP.DCP + +/-! +# Support for multi-level atom declarations + +This file defines the method `canonAtomData` which allows atom declarations to not be in conic form, +as long as they can be canonized into conic using the existing atoms at that point. This allows for +more readable atom declarations. + +Note that this means that what is written after `declare_atom` is not necessarily what is stored in +the atom library, since the atom declaration might need to be canonized into conic form. +-/ + +namespace CvxLean + +open Lean Meta Elab Expr + +section Helpers + +/-- Count the number of arguments with non-trivial monotonicities. -/ +def getMonoArgsCount (curv : Curvature) (argKinds : Array ArgKind) : Nat := + argKinds.foldl (fun acc argKind => + match curvatureInArg curv argKind with + | Curvature.Concave => 1 + acc + | Curvature.Convex => 1 + acc + | _ => acc) 0 + +end Helpers + +section MultiLevel + +/-- Use the DCP procedure on the optimization problem represented by the graph implementation. This +is the first step, we only have a raw processed atom tree at this point. -/ +def processGraphImplementation (objCurv : Curvature) (atomData : GraphAtomData) (xs : Array Expr) : + MetaM (DCP.ProcessedAtomTree × Array LocalDecl) := do + -- First, we set up the optimization problem corresponding to the graph implementation as + -- follows. The objective function corresponds to the implementation objective, the + -- constraints correspond to the implmenetation constraints, and the optimization variables + -- correspond to the implementation variables. + let (objFun, constraints, originalVarsDecls) ← do + let impVars := atomData.impVars.map fun (n, ty) => (n, mkAppNBeta ty xs) + withLocalDeclsDNondep impVars fun vs => do + let vsDecls ← vs.mapM fun v => v.fvarId!.getDecl + let bconds := atomData.bconds.map fun (n,c) => (n, mkAppNBeta c xs) + withLocalDeclsDNondep bconds fun bs => do + let originalVarsDecls := vsDecls + let objFun := mkAppNBeta (mkAppNBeta atomData.impObjFun xs) vs + let constraints := atomData.impConstrs.map + fun c => (`_, mkAppNBeta (mkAppNBeta (mkAppNBeta c xs) bs) vs) + return (objFun, constraints, originalVarsDecls) + + -- Thge atom arguments are not optimization variables, but still need to be seen as affine + -- expressions by the DCP procedure. Here, we collect all the non-constant arguments as the + -- free variable expressions that may appear in the graph implementation. + let xsAffPre := xs.map fun x => x.fvarId! + let mut xsAff := #[] + for i in [:xs.size] do + if atomData.argKinds[i]! != ArgKind.Constant then + xsAff := xsAff.push xsAffPre[i]! + + -- This is the key step where we ask the DCP procedure to canonize the optimization problem + -- corresponding to the graph implementation. We indicate that the some variables coming from + -- the atom arguments might be affine. We also include the background conditions, which may + -- be needed to canonize the problem. + let bconds := atomData.bconds.map fun (n, c) => (n, mkAppNBeta c xs) + let pat ← withLocalDeclsDNondep bconds fun bs => do + let bcondsDecls ← bs.mapM (·.fvarId!.getDecl) + DCP.mkProcessedAtomTree objCurv objFun constraints.toList originalVarsDecls + (extraVars := xsAff) (extraDecls := bcondsDecls) + + -- We return the result from DCP, and the optimization variables (as local declarations), which, + -- recall, are the implementation variables. + return (pat, originalVarsDecls) + +/-- The second step is, given the processed atom tree and the raw atom data, we need to adjust all +the fields and proofs. -/ +def adjustCanonAtomData (objCurv : Curvature) (atomData : GraphAtomData) + (pat : DCP.ProcessedAtomTree) (originalVarsDecls : Array LocalDecl) (xs : Array Expr) : + MetaM GraphAtomData := do + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls pat.newVarDecls do + withExistingLocalDecls pat.newConstrVarsArray.toList do + -- `vs` are the variables already present in the uncanonized graph implementation, i.e., + -- explicitly written by the user in the atom declaration. + let vs := originalVarsDecls.map (mkFVar ·.fvarId) + + -- `zs` are the implementation variables to be added that come from the canonization of + -- the atom declaration iteself. + let zs := pat.newVarDecls.toArray.map (mkFVar ·.fvarId) + + -- Uncanonized atom conditions. + let vconds := atomData.vconds.map fun (n, c) => (n, mkAppNBeta c xs) + let bconds := atomData.bconds.map fun (n, c) => (n, mkAppNBeta c xs) + + /- 1. Adjust the implementation variables, essentially (`vs` and `zs`). -/ + + let adjustedImpVars := atomData.impVars.append (← pat.newVarDecls.toArray.mapM + fun decl => do return (decl.userName, ← mkLambdaFVars xs decl.type)) + + let _ ← adjustedImpVars.mapM (fun (_, ty) => check ty) + trace[CvxLean.debug] "adjustedImpVars: {adjustedImpVars}" + + /- 2. Adjust the implementation objective. -/ + + let adjustedImpObjFun ← mkLambdaFVars (xs ++ vs ++ zs) pat.canonExprs.objFun.val + + trace[CvxLean.debug] "adjustedImpObjFun: {adjustedImpObjFun}" + + /- 3. Adjust the implementation constraints. -/ + + -- Canonized implementation constraints (in conic form), filter out those marked as + -- vconds. + let canonConstrs := pat.canonExprs.constrs.map Tree.val |>.filterIdx + (fun i => !pat.isVCond[i]!) + + -- DCP might have also added new constraints. + let newConstrs := pat.newConstrs + + -- The adjusted constraints need to take into account the `zs`. Recall that implementation + -- constraints may also use background conditions. + let adjustedImpConstrs ← (canonConstrs ++ newConstrs).mapM (fun c => do + withLocalDeclsDNondep bconds fun as => mkLambdaFVars (xs ++ as ++ vs ++ zs) c) + + let _ ← adjustedImpConstrs.mapM check + trace[CvxLean.debug] "adjustedImpConstrs: {adjustedImpConstrs}" + + /- 4. Adjust the solution. -/ + + -- Solutions to `vs` (in terms of `xs`). + let vsSol := atomData.solution.map fun s => mkAppNBeta s xs + + -- Substitute into the canonization forward map the solutions to `vs`, which gives us the + -- solutions of `zs` as expressions in `xs`. + let zsSolInTermsOfXs ← pat.forwardImagesNewVars.mapM (fun e => mkLambdaFVars xs + (e.replaceFVars vs vsSol)) + + -- The adjusted solutions, are the solutions to both `vs` abd `zs` as expressions in `xs`. + let adjustedSolution := atomData.solution ++ zsSolInTermsOfXs + + let _ ← adjustedSolution.mapM check + trace[CvxLean.debug] "adjustedSolution: {adjustedSolution}" + + /- 5. Adjust the feasibility proofs. -/ + + let solEqAtomProofs := pat.solEqAtom.constrs.map Tree.val + + -- First, we change all the existing feasibility proofs to take into account the solutions + -- for `zs`. + let mut oldFeasibilityAdjusted := #[] + for i in [:atomData.feasibility.size] do + let feas := atomData.feasibility[i]! + let feasXs := mkAppNBeta feas xs + let adjustedFeas ← + withLocalDeclsDNondep vconds fun cs => do + withLocalDeclsDNondep bconds fun bs => do + let feasXsVconds := mkAppNBeta feasXs cs + let feasXsConds := mkAppNBeta feasXsVconds bs + let proofAdjusted := solEqAtomProofs[i]!.replaceFVars vs vsSol + check proofAdjusted + let adjustedFeas ← mkAppM ``Eq.mpr #[proofAdjusted, feasXsConds] + + mkLambdaFVars (xs ++ cs ++ bs) adjustedFeas + oldFeasibilityAdjusted := oldFeasibilityAdjusted.push adjustedFeas + + -- The new feasibility proofs are in terms of `vs` so we need to replace them by their + -- solutions. + let newConstrProofs := pat.feasibility.fold #[] fun acc fs => fs.fold acc Array.append + let newFeasibility ← newConstrProofs.mapM (fun e => + withLocalDeclsDNondep vconds fun cs => do + withLocalDeclsDNondep bconds fun bs => do + mkLambdaFVars (xs ++ cs ++ bs) (e.replaceFVars vs vsSol)) + + let adjustedFeasibility := oldFeasibilityAdjusted ++ newFeasibility + + let _ ← adjustedFeasibility.mapM check + trace[CvxLean.debug] "adjustedFeasibility: {adjustedFeasibility}" + + /- 6. Adjust the optimality proof. -/ + + let constrsOptimalityProofs := pat.optimality.constrs.map Tree.val + let objFunOptimalityProof := pat.optimality.objFun.val + let adjustedOptimality ← + withLocalDeclsDNondep bconds fun bs => do + -- Apply variables, background conditions, and implementation variables to the + -- user-provided optimality proof. + let optimalityXsBConds := mkAppN atomData.optimality (xs ++ bs ++ vs) + withLocalDeclsDNondep (canonConstrs.map (fun rc => (`_, rc))) fun cs => do + -- Then, apply all constraints. + let mut optimalityAfterCanon := optimalityXsBConds + for i in [:canonConstrs.size] do + let c := mkApp constrsOptimalityProofs[i]! cs[i]! + optimalityAfterCanon := mkApp optimalityAfterCanon c + -- At this point, we have the monotonicity conditions implying the optimality + -- condition. + -- Then, adjust the condition using `objFunOptimalityProof`. We need to deal with + -- the monotonicity conditions carefully as well. + let monoArgsCount := getMonoArgsCount objCurv atomData.argKinds + let optCondAdjusted ← + lambdaTelescope (← whnf optimalityAfterCanon) <| fun optArgs e => do + -- Every extra argument has an extra condition, e.g. x', x ≤ x. + let monoArgs := optArgs[:2 * monoArgsCount] + let optCondition ← mkLambdaFVars optArgs[2 * monoArgsCount:] e + let atomRange ← inferType (mkAppNBeta atomData.expr xs) + -- Obtain the new condition by transitivty. + let newCond ← + if atomData.curvature == Curvature.Convex then + mkAppOptM ``le_trans #[ + atomRange, none, none, none, none, + optCondition, objFunOptimalityProof] + else + mkAppOptM ``le_trans #[ + atomRange, none, none, none, none, + objFunOptimalityProof, optCondition] + mkLambdaFVars monoArgs newCond + + let ds := pat.newConstrVarsArray.map (mkFVar ·.fvarId) + mkLambdaFVars (xs ++ bs ++ vs ++ zs ++ cs ++ ds) optCondAdjusted + + check adjustedOptimality + trace[CvxLean.debug] "adjustedOptimality: {adjustedOptimality}" + + /- 7. Adjust vconditionElimination -/ + + let mut adjustedVCondElim := #[] + for vcondElim in atomData.vcondElim do + let newVCondElim := mkAppN vcondElim (xs ++ vs) + let newVCondElim ← + withLocalDeclsDNondep (canonConstrs.map (fun rc => (`_, rc))) fun cs => do + let mut newVCondElim := newVCondElim + for i in [:canonConstrs.size] do + let c := mkApp constrsOptimalityProofs[i]! cs[i]! + newVCondElim := mkApp newVCondElim c + let ds := pat.newConstrVarsArray.map (mkFVar ·.fvarId) + mkLambdaFVars (xs ++ vs ++ zs ++ cs ++ ds) newVCondElim + adjustedVCondElim := adjustedVCondElim.push newVCondElim + + let _ ← adjustedVCondElim.mapM check + trace[CvxLean.debug] "adjustedVCondElim: {adjustedVCondElim}" + + /- 8. Return the adjusted atom data. -/ + + return { atomData with + impVars := adjustedImpVars + impObjFun := adjustedImpObjFun + impConstrs := adjustedImpConstrs + solution := adjustedSolution + feasibility := adjustedFeasibility + optimality := adjustedOptimality + vcondElim := adjustedVCondElim + } + +/-- Given raw atom data as defined in the user (which might not be in conic form), generate new atom +data in conic form. This function calls `processGraphImplementation` to apply DCP on the +optimization problem represented by the atom's implementation and `adjustCanonAtomData` to fix all +the atom data, including the proofs. -/ +def canonAtomData (objCurv : Curvature) (atomData : GraphAtomData) : MetaM GraphAtomData := do + -- `xs` are the arguments of the atom. + lambdaTelescope atomData.expr fun xs _ => do + let (pat, originalVarsDecls) ← processGraphImplementation objCurv atomData xs + adjustCanonAtomData objCurv atomData pat originalVarsDecls xs + +end MultiLevel + +end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomCurvRule.lean b/CvxLean/Tactic/DCP/AtomCurvRule.lean new file mode 100644 index 00000000..769fa3ea --- /dev/null +++ b/CvxLean/Tactic/DCP/AtomCurvRule.lean @@ -0,0 +1,59 @@ +import Lean + +/-! +# Curvature rules + +This file contains the definition of the `Curvature` and `ArgKind` types, as well as the function +`curvatureInArg` that calculates the curvature of an atom depending on the monotonicity of its +arguments. This corresponds to the DCP composition rules. It is used in a crucial way to make atom +trees by the `dcp` tactic. +-/ + +namespace CvxLean + +open Lean + +/-- An atom can be affine, concave or convex. -/ +inductive Curvature where + | Constant | Affine | Concave | Convex | AffineSet | ConvexSet +deriving BEq, Inhabited + +instance : ToMessageData Curvature where + toMessageData + | Curvature.Constant => "constant" + | Curvature.Affine => "affine" + | Curvature.Concave => "concave" + | Curvature.Convex => "convex" + | Curvature.AffineSet => "affine_set" + | Curvature.ConvexSet => "convex_set" + +/-- Arguments of an atom are increasing (+), decreasing (-), constant (&) or neither (?). -/ +inductive ArgKind where + | Increasing | Decreasing | Neither | Constant +deriving BEq, Inhabited + +instance : ToMessageData ArgKind where + toMessageData + | ArgKind.Increasing => "increasing" + | ArgKind.Decreasing => "decreasing" + | ArgKind.Neither => "neither" + | ArgKind.Constant => "constant" + +/-- Given an expression of the form `f(e₁, ..., eₙ)`, where `f` corresponds to an atom with +curvature `c` and its `i`th argument has monotonicity (or, rather, `ArgKind`) `aᵢ`, this function +calculates the **expected** curvature of `eᵢ` from `c` and `aᵢ`. -/ +def curvatureInArg : Curvature → ArgKind → Curvature + | _, ArgKind.Constant => Curvature.Constant + | Curvature.Concave, ArgKind.Increasing => Curvature.Concave + | Curvature.Concave, ArgKind.Decreasing => Curvature.Convex + | Curvature.Convex, ArgKind.Increasing => Curvature.Convex + | Curvature.Convex, ArgKind.Decreasing => Curvature.Concave + -- This requires some explanation. The only set atom that allows non-affine arguments is `≤`. The + -- rest, equality and cones, require affine arguments. We say that an expression `x ≤ y` is + -- decreasing in `x` in the following sense: if `x' ≤ x` then `x ≤ y → x' ≤ y`, where `→` is seen + -- as the order relation in `Prop`. + | Curvature.ConvexSet, ArgKind.Increasing => Curvature.Concave + | Curvature.ConvexSet, ArgKind.Decreasing => Curvature.Convex + | _, _ => Curvature.Affine + +end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomExt.lean b/CvxLean/Tactic/DCP/AtomExt.lean index 12672d64..52dd36e0 100644 --- a/CvxLean/Tactic/DCP/AtomExt.lean +++ b/CvxLean/Tactic/DCP/AtomExt.lean @@ -1,58 +1,19 @@ import CvxLean.Tactic.DCP.DiscrTree +import CvxLean.Tactic.DCP.AtomCurvRule /-! -# Atom library (extension) +# Atom library (environment extension) -This file defines the `GraphAtomData` type, which is what is stored when declaring an atom. This -includes defining `Curvature` and `ArgKind`, and the composition rule using them. +This file defines the `GraphAtomData` type, which is what is stored when declaring an atom. Its +fields are curvature and monotonicity attributes and several Lean expressions. -The enviromnet extesion is also defined here. +The enviroment extesion is also defined here. We store a discrimination tree built from all the +atoms in the library based on the atom's expression (which is the function that the atom describes). -/ namespace CvxLean -open Lean Lean.Meta - -/-- An atom can be affine, concave or convex. -/ -inductive Curvature where - | Constant | Affine | Concave | Convex | AffineSet | ConvexSet -deriving BEq, Inhabited - -/-- Arguments of an atom are increasing (+), decreasing (-), constant (&) or neither (?). -/ -inductive ArgKind where - | Increasing | Decreasing | Neither | Constant -deriving BEq, Inhabited - -/-- Caclulate curvature depending on monotonicity. -/ -def curvatureInArg : Curvature → ArgKind → Curvature - | _, ArgKind.Constant => Curvature.Constant - | Curvature.Concave, ArgKind.Increasing => Curvature.Concave - | Curvature.Concave, ArgKind.Decreasing => Curvature.Convex - | Curvature.Convex, ArgKind.Increasing => Curvature.Convex - | Curvature.Convex, ArgKind.Decreasing => Curvature.Concave - -- This requires some explanation. The only set atom that allows non-affine arguments is `≤`. The - -- rest, equality and cones, require affine arguments. We say that an expression `x ≤ y` is - -- decreasing in `x` in the following sense: if `x' ≤ x` then `x ≤ y → x' ≤ y`, where `→` is seen - -- as the order relation in `Prop`. - | Curvature.ConvexSet, ArgKind.Increasing => Curvature.Concave - | Curvature.ConvexSet, ArgKind.Decreasing => Curvature.Convex - | _, _ => Curvature.Affine - -instance : ToMessageData Curvature where - toMessageData - | Curvature.Constant => "constant" - | Curvature.Affine => "affine" - | Curvature.Concave => "concave" - | Curvature.Convex => "convex" - | Curvature.AffineSet => "affine_set" - | Curvature.ConvexSet => "convex_set" - -instance : ToMessageData ArgKind where - toMessageData - | ArgKind.Increasing => "increasing" - | ArgKind.Decreasing => "decreasing" - | ArgKind.Neither => "neither" - | ArgKind.Constant => "constant" +open Lean Meta section AtomData @@ -62,9 +23,9 @@ structure GraphAtomData where curvature : Curvature expr : Expr argKinds : Array ArgKind - bconds : Array (Lean.Name × Expr) - vconds : Array (Lean.Name × Expr) - impVars : Array (Lean.Name × Expr) + bconds : Array (Name × Expr) + vconds : Array (Name × Expr) + impVars : Array (Name × Expr) impObjFun : Expr impConstrs : Array Expr solution : Array Expr @@ -129,11 +90,11 @@ initialize atomExtension : AtomExtension ← do } return atomExtension -/-- Add a new atom to the library. -/ +/-- Add a new atom to the library (discrimination tree). -/ def addAtom (data : AtomData) : MetaM Unit := do let (_, _, expr) ← lambdaMetaTelescope data.expr let keys ← DiscrTree.mkPath expr - setEnv $ atomExtension.addEntry (← getEnv) (keys, data) + setEnv <| atomExtension.addEntry (← getEnv) (keys, data) /-- Get the atom tree. -/ def getAtomDiscrTree : MetaM (DiscrTree AtomData) := do diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Abs.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Abs.lean index edf2de64..935383a7 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Abs.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Abs.lean @@ -4,6 +4,10 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Add import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sub import CvxLean.Lib.Math.Data.Matrix +/-! +Absolute value atom (convex). +-/ + namespace CvxLean open Real @@ -13,23 +17,23 @@ vconditions implementationVars (t : ℝ) implementationObjective t implementationConstraints - (c_pos : posOrthCone (t - x)) - (c_neg : posOrthCone (t + x)) + (c_pos : nonnegOrthCone (t - x)) + (c_neg : nonnegOrthCone (t + x)) solution (t := abs x) solutionEqualsAtom rfl feasibility (c_pos : by - unfold posOrthCone + unfold nonnegOrthCone rw [sub_nonneg] exact le_abs_self x) (c_neg : by - unfold posOrthCone - rw [←neg_le_iff_add_nonneg'] + unfold nonnegOrthCone + rw [← neg_le_iff_add_nonneg'] exact neg_abs_le x) optimality by apply abs_le.2 - rw [←sub_nonneg, sub_neg_eq_add, add_comm, ←sub_nonneg (b := x)] - unfold posOrthCone at c_pos c_neg + rw [← sub_nonneg, sub_neg_eq_add, add_comm, ← sub_nonneg (b := x)] + unfold nonnegOrthCone at c_pos c_neg exact ⟨c_neg, c_pos⟩ vconditionElimination @@ -38,48 +42,47 @@ vconditions implementationVars (t : (Fin n) → ℝ) implementationObjective t implementationConstraints - (c_pos : Real.Vec.posOrthCone (t - x : (Fin n) → ℝ)) - (c_neg : Real.Vec.posOrthCone (t + x : (Fin n) → ℝ)) + (c_pos : Real.Vec.nonnegOrthCone (t - x : (Fin n) → ℝ)) + (c_neg : Real.Vec.nonnegOrthCone (t + x : (Fin n) → ℝ)) solution (t := abs x) solutionEqualsAtom rfl feasibility (c_pos : by - unfold Real.Vec.posOrthCone + unfold Real.Vec.nonnegOrthCone intros apply abs.feasibility0) (c_neg : by - unfold Real.Vec.posOrthCone + unfold Real.Vec.nonnegOrthCone intros apply abs.feasibility1) optimality by intros i - unfold Real.Vec.posOrthCone at c_pos c_neg + unfold Real.Vec.nonnegOrthCone at c_pos c_neg apply abs.optimality _ _ (c_pos i) (c_neg i) vconditionElimination -declare_atom Matrix.abs [convex] - (m : Nat)& (n : Nat)& (M : Matrix.{0,0,0} (Fin m) (Fin n) ℝ)? - : M.abs := +declare_atom Matrix.abs [convex] (m : Nat)& (n : Nat)& (M : Matrix.{0,0,0} (Fin m) (Fin n) ℝ)? : + M.abs := vconditions implementationVars (T : Matrix (Fin m) (Fin n) ℝ) implementationObjective T implementationConstraints - (c_pos : Real.Matrix.posOrthCone (T - M : Matrix (Fin m) (Fin n) ℝ)) - (c_neg : Real.Matrix.posOrthCone (T + M : Matrix (Fin m) (Fin n) ℝ)) + (c_pos : Real.Matrix.nonnegOrthCone (T - M : Matrix (Fin m) (Fin n) ℝ)) + (c_neg : Real.Matrix.nonnegOrthCone (T + M : Matrix (Fin m) (Fin n) ℝ)) solution (T := M.abs) solutionEqualsAtom rfl feasibility (c_pos : by - unfold Real.Matrix.posOrthCone + unfold Real.Matrix.nonnegOrthCone intros _ _ _ apply abs.feasibility0) (c_neg : by - unfold Real.Matrix.posOrthCone + unfold Real.Matrix.nonnegOrthCone intros _ _ _ apply abs.feasibility1) optimality by intros i j - unfold Real.Matrix.posOrthCone at c_pos c_neg + unfold Real.Matrix.nonnegOrthCone at c_pos c_neg apply abs.optimality _ _ (c_pos i j) (c_neg i j) vconditionElimination diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Add.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Add.lean index 56babb01..02bdd9a4 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Add.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Add.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Addition atom (affine). +-/ + namespace CvxLean declare_atom add [affine] (x : ℝ)+ (y : ℝ)+ : x + y := @@ -9,7 +13,7 @@ homogenity by simp [mul_add] additivity by simp only [add_zero, add_assoc, add_comm] - rw [add_comm x' y', ←add_assoc y y' x', add_comm _ x'] + rw [add_comm x' y', ← add_assoc y y' x', add_comm _ x'] optimality fun _ _ => add_le_add declare_atom Vec.add [affine] (m : Nat)& (x : Fin m → ℝ)+ (y : Fin m → ℝ)+ : x + y := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/All.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/All.lean index e474dea2..2a217b9c 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/All.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/All.lean @@ -25,8 +25,6 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.MulVec import CvxLean.Tactic.DCP.AtomLibrary.Fns.Neg import CvxLean.Tactic.DCP.AtomLibrary.Fns.Norm import CvxLean.Tactic.DCP.AtomLibrary.Fns.Nth -import CvxLean.Tactic.DCP.AtomLibrary.Fns.Perspective -import CvxLean.Tactic.DCP.AtomLibrary.Fns.PosSemidef import CvxLean.Tactic.DCP.AtomLibrary.Fns.Power import CvxLean.Tactic.DCP.AtomLibrary.Fns.QuadOverLin import CvxLean.Tactic.DCP.AtomLibrary.Fns.SMul diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/CumSum.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/CumSum.lean index 244e7542..38b26121 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/CumSum.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/CumSum.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Vec +/-! +Cumulative sum atom (affine). +-/ + namespace CvxLean open Vec @@ -9,16 +13,16 @@ declare_atom Vec.cumsum [affine] (n : Nat)& (x : Fin n → ℝ)+ : cumsum x := bconditions homogenity by unfold Vec.cumsum; split_ifs - . funext i; simp [Finset.mul_sum] - . ring + · funext i; simp [Finset.mul_sum] + · ring additivity by unfold Vec.cumsum; split_ifs - . funext i; simp [Finset.sum_add_distrib] - . ring + · funext i; simp [Finset.sum_add_distrib] + · ring optimality by intro x' hxx' unfold Vec.cumsum; split_ifs - . intros i; apply Finset.sum_le_sum; intros j _; exact hxx' j - . simp + · intros i; apply Finset.sum_le_sum; intros j _; exact hxx' j + · simp end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diag.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diag.lean index ae2efb26..cbcb33b8 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diag.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diag.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Atom for the operation `ℝⁿˣⁿ → ℝⁿ` that outputs the diagonal of a matrix (affine). +-/ + namespace CvxLean declare_atom Matrix.diag [affine] (n : ℕ)& (A : Matrix.{0,0,0} (Fin n) (Fin n) ℝ)+ : A.diag := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diagonal.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diagonal.lean index 5d234d2c..25cdb3fa 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diagonal.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Diagonal.lean @@ -1,6 +1,11 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Atom for the operation that turns a vector into a diagonal matrix with zeros outside the diagonal +(affine). +-/ + namespace CvxLean open Matrix diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Div.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Div.lean index f93fd087..356f432e 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Div.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Div.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Division atom (affine). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/DotProduct.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/DotProduct.lean index 1f7df35f..b8639c7b 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/DotProduct.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/DotProduct.lean @@ -1,8 +1,13 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Dot product atom (affine). +-/ + namespace CvxLean -declare_atom Vec.dotProduct1 [affine] (m : Nat)& (x : Fin m → ℝ)& (y : Fin m → ℝ)? : Matrix.dotProduct x y := +declare_atom Vec.dotProduct1 [affine] (m : Nat)& (x : Fin m → ℝ)& (y : Fin m → ℝ)? : + Matrix.dotProduct x y := bconditions homogenity by rw [Matrix.dotProduct_zero, smul_zero, add_zero, add_zero, @@ -11,7 +16,8 @@ additivity by rw [Matrix.dotProduct_zero, add_zero, Matrix.dotProduct_add] optimality le_refl _ -declare_atom Vec.dotProduct2 [affine] (m : Nat)& (x : Fin m → ℝ)? (y : Fin m → ℝ)& : Matrix.dotProduct x y := +declare_atom Vec.dotProduct2 [affine] (m : Nat)& (x : Fin m → ℝ)? (y : Fin m → ℝ)& : + Matrix.dotProduct x y := bconditions homogenity by rw [Matrix.zero_dotProduct, smul_zero, add_zero, add_zero, diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Entr.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Entr.lean index 87e2bdb7..fd7f511d 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Entr.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Entr.lean @@ -3,6 +3,10 @@ import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Lib.Math.Data.Real import CvxLean.Lib.Math.Data.Vec +/-! +Entropy atom (concave). +-/ + namespace CvxLean open Real @@ -22,7 +26,7 @@ feasibility | inl h => left refine ⟨h, ?_⟩ - rw [neg_div, mul_comm _ (log x), ←mul_div] + rw [neg_div, mul_comm _ (log x), ← mul_div] rw [div_self (ne_of_lt h).symm, mul_one, exp_neg, exp_log h] rw [inv_eq_one_div, mul_div, mul_one, div_self (ne_of_lt h).symm] | inr h => @@ -34,8 +38,8 @@ optimality by unfold expCone at c cases c with | inl c => - rw [mul_comm, ←neg_mul, ←div_le_iff c.1] - rw [←exp_le_exp, exp_neg, exp_log c.1, inv_eq_one_div] + rw [mul_comm, ← neg_mul, ← div_le_iff c.1] + rw [← exp_le_exp, exp_neg, exp_log c.1, inv_eq_one_div] rw [le_div_iff c.1, mul_comm] exact c.2 | inr c => diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Exp.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Exp.lean index 3827b0d8..4e291822 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Exp.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Exp.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Lib.Math.Data.Vec +/-! +Exponential function atom (convex). +-/ + namespace CvxLean open Real @@ -19,7 +23,7 @@ feasibility unfold expCone; simp) optimality by intros x' hx - rw [←exp_iff_expCone] at c_exp + rw [← exp_iff_expCone] at c_exp have hexpleexp := exp_le_exp.2 hx exact hexpleexp.trans c_exp vconditionElimination diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/FromBlocks.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/FromBlocks.lean index f52c6c32..00a939c8 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/FromBlocks.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/FromBlocks.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Atom for the operation that makes a matrix from four blocks (affine). +-/ + namespace CvxLean open Matrix @@ -19,9 +23,9 @@ additivity by optimality by intros A' B' C' D' hA hB hC hD i j cases i <;> cases j <;> simp [fromBlocks] - . exact hA _ _ - . exact hB _ _ - . exact hC _ _ - . exact hD _ _ + · exact hA _ _ + · exact hB _ _ + · exact hC _ _ + · exact hD _ _ end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/GeoMean.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/GeoMean.lean index 772adeb0..afd4158e 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/GeoMean.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/GeoMean.lean @@ -8,11 +8,15 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Mul import CvxLean.Tactic.DCP.AtomLibrary.Fns.Exp import CvxLean.Lib.Math.Data.Vec +/-! +Geometric mean atom (concave). +-/ + namespace CvxLean open Real --- TODO(RFM): generalize to x : Fin n → ℝ +-- TODO: generalize to x : Fin n → ℝ declare_atom geoMean [concave] (x : ℝ)+ (y : ℝ)+ : sqrt (x * y) := vconditions (hx : 0 ≤ x) diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Huber.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Huber.lean index 45542ade..571c4486 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Huber.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Huber.lean @@ -9,6 +9,11 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Mul import CvxLean.Tactic.DCP.AtomLibrary.Fns.SMul import CvxLean.Lib.Math.Data.Real import CvxLean.Lib.Math.Data.Vec +import CvxLean.Meta.Util.Debug + +/-! +Unit half-with Huber loss function atom (convex). +-/ namespace CvxLean @@ -41,7 +46,7 @@ optimality by { by_cases hwx : w ≤ |x| { have hsq : |x| ^ (2 : ℝ) - w ^ (2 : ℝ) = (|x| + w) * (|x| - w) := by ring_nf; simp - rw [←sub_le_iff_le_add, ←sq_abs, ←rpow_two, ←rpow_two, hsq] + rw [← sub_le_iff_le_add, ← sq_abs, ← rpow_two, ← rpow_two, hsq] apply mul_le_mul <;> linarith } { replace hwx := not_le.mp hwx have hw : 0 ≤ w := by linarith [abs_nonneg x] @@ -50,14 +55,14 @@ optimality by apply add_nonneg <;> norm_num [c2, hw] } have hcx : 0 < |x| := abs_pos.mpr hxz have hx2w2 := mul_lt_mul hwx (le_of_lt hwx) hcx hw - rw [←pow_two, sq_abs, ←pow_two] at hx2w2 + rw [← pow_two, sq_abs, ← pow_two] at hx2w2 have hv : 0 ≤ 2 * v := by linarith have hx2w22v := add_le_add (le_of_lt hx2w2) hv linarith } } { have h2xsub1 : 2 * |x| - 1 ≤ 2 * v + (2 * w - 1) := by linarith apply le_trans h2xsub1 apply add_le_add (le_refl _) - rw [←zero_add (_ - _), ←le_sub_iff_add_le] + rw [← zero_add (_ - _), ← le_sub_iff_add_le] have hwsub12 : w * w - (2 * w - 1) = (w - 1) * (w - 1) := by ring_nf rw [pow_two, hwsub12] exact mul_self_nonneg _ } @@ -76,37 +81,37 @@ solution (w := fun i => if |x i| ≤ 1 then |x i| else 1) solutionEqualsAtom by ext i - simp [Vec.huber, ←huber.solEqAtom (x i)] + simp [Vec.huber, ← huber.solEqAtom (x i)] feasibility (c1 : by dsimp intros i have h := huber.feasibility0 (x i) - unfold Real.posOrthCone at h + unfold Real.nonnegOrthCone at h simpa using h) (c2 : by dsimp intros i have h := huber.feasibility1 (x i) - unfold Real.posOrthCone at h + unfold Real.nonnegOrthCone at h simpa using h) (c3 : by dsimp intros i have h := huber.feasibility2 (x i) - unfold Real.posOrthCone at h + unfold Real.nonnegOrthCone at h simpa using h) optimality by intros i simp [Vec.huber] - rw [←rpow_two] + rw [← rpow_two] apply huber.optimality (x i) (v i) (w i) ((w i) ^ 2) (abs (x i)) - { unfold posOrthCone; simpa using c1 i } - { unfold posOrthCone; simpa using c2 i } - { unfold posOrthCone; simpa using c3 i } + { unfold nonnegOrthCone; simpa using c1 i } + { unfold nonnegOrthCone; simpa using c2 i } + { unfold nonnegOrthCone; simpa using c3 i } { unfold rotatedSoCone; simp [sq_nonneg]; norm_num } - { unfold posOrthCone; simp [sub_nonneg, abs] } - { unfold posOrthCone; simp [← sub_le_iff_le_add, abs] } + { unfold nonnegOrthCone; simp [sub_nonneg, abs] } + { unfold nonnegOrthCone; simp [← sub_le_iff_le_add, abs] } vconditionElimination diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Inv.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Inv.lean index 87ce4b9b..948d4fdf 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Inv.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Inv.lean @@ -8,13 +8,17 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Mul import CvxLean.Tactic.DCP.AtomLibrary.Fns.Power import CvxLean.Tactic.DCP.AtomLibrary.Fns.Transpose +/-! +Inverse function atom (convex). +-/ + namespace CvxLean open Real --- TODO: These are just the same as `powNegOne`. We need them for pattern --- mathcing. The issue is that implementationConstraints does not know about --- vconditions. If we fix that, then we can just say (c : x ^ (-1) ≤ t). +-- TODO: These are just the same as `powNegOne`. We need them for pattern matching. The issue is +-- that `implementationConstraints` does not know about `vconditions`. If we fix that, then we can +-- just say `(c : x ^ (-1) ≤ t)`. declare_atom inv [convex] (x : ℝ)- : x⁻¹ := vconditions @@ -43,7 +47,7 @@ optimality by have hxpos : 0 < x := by cases (lt_or_eq_of_le c3) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith have hypos := lt_of_lt_of_le hxpos hy rw [inv_eq_pow_neg_one hypos] rw [rpow_neg_one, inv_eq_one_div, div_le_iff hypos] @@ -55,7 +59,7 @@ vconditionElimination rw [soCone_add_sub_two_of_nonneg c2 c3] at c1 cases (lt_or_eq_of_le c3) with | inl h => linarith - | inr h => rw [←h] at c1; linarith) + | inr h => rw [← h] at c1; linarith) declare_atom oneDiv [convex] (x : ℝ)- : 1 / x := vconditions @@ -84,7 +88,7 @@ optimality by have hxpos : 0 < x := by cases (lt_or_eq_of_le c3) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith have hypos := lt_of_lt_of_le hxpos hy rw [one_div_eq_pow_neg_one hypos] rw [rpow_neg_one, inv_eq_one_div, div_le_iff hypos] @@ -96,7 +100,7 @@ vconditionElimination rw [soCone_add_sub_two_of_nonneg c2 c3] at c1 cases (lt_or_eq_of_le c3) with | inl h => linarith - | inr h => rw [←h] at c1; linarith) + | inr h => rw [← h] at c1; linarith) open Matrix @@ -135,7 +139,7 @@ optimality by have hxipos : 0 < x i := by cases (lt_or_eq_of_le (c3 i)) with | inl h => exact h - | inr h => simp [←h] at c1i; linarith + | inr h => simp [← h] at c1i; linarith have hyipos := lt_of_lt_of_le hxipos (hy i) rw [one_div_eq_pow_neg_one hyipos] rw [rpow_neg_one, inv_eq_one_div, div_le_iff hyipos] @@ -151,6 +155,6 @@ vconditionElimination | inl h => have hyi := hy i simp at h hyi ⊢; linarith - | inr h => simp [←h] at c1i; linarith) + | inr h => simp [← h] at c1i; linarith) end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/KLDiv.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/KLDiv.lean index 61107270..d7f9298f 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/KLDiv.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/KLDiv.lean @@ -6,12 +6,15 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sub import CvxLean.Lib.Math.Data.Real import CvxLean.Lib.Math.Data.Vec +/-! +Kullback-Leibler (KL) divergence, or relative entropy atom (convex). +-/ + namespace CvxLean open Real --- TODO(RFM): same issue as in `quadOverLin`, no clear way to have 0 < y as a --- vcondition, the exp y' ≤ y trick makes the solver stall. +-- TODO: same issue as in `quadOverLin`, no clear way to have `0 < y` as a `vcondition`. declare_atom klDiv [convex] (x : ℝ)? (y : ℝ)? : x * log (x / y) - x + y := vconditions (hx : 0 ≤ x) @@ -33,7 +36,7 @@ feasibility left refine ⟨hx, ?_⟩ have hypos : 0 < y := by positivity - rw [mul_comm _ (log _), ←neg_mul, ←mul_div, div_self (ne_of_gt hx)] + rw [mul_comm _ (log _), ← neg_mul, ← mul_div, div_self (ne_of_gt hx)] rw [mul_one, exp_neg, exp_log (div_pos hx hypos), inv_div, mul_div] rw [div_le_iff hx, mul_comm] | inr hx => @@ -54,8 +57,8 @@ optimality by have hrhs : y - x - t = -t + (y - x) := by ring conv => rhs; rw [hrhs] apply add_le_add_right - rw [mul_comm, ←le_div_iff c.1, neg_div] - rw [←exp_le_exp, exp_log (div_pos hxpos hypos), exp_neg, ←inv_div] + rw [mul_comm, ← le_div_iff c.1, neg_div] + rw [← exp_le_exp, exp_log (div_pos hxpos hypos), exp_neg, ← inv_div] rw [inv_le_inv (div_pos hypos hxpos) (exp_pos _)] rw [le_div_iff hxpos, mul_comm] exact c.2 @@ -88,20 +91,20 @@ feasibility (c2 : by dsimp have h := klDiv.feasibility1 x y hx hy - unfold posOrthCone at h + unfold nonnegOrthCone at h simpa using h) optimality by apply klDiv.optimality x y t c1 - { unfold posOrthCone expCone at * + { unfold nonnegOrthCone expCone at * simpa using c2 } vconditionElimination (hx : by apply klDiv.vcondElim0 x y t c1 - unfold posOrthCone at * + unfold nonnegOrthCone at * simpa using c2) (hy : by apply klDiv.vcondElim1 x y t c1 - unfold posOrthCone expCone at * + unfold nonnegOrthCone expCone at * simpa using c2) declare_atom Vec.klDiv [convex] (m : Nat)& (x : Fin m → ℝ)? (y : Fin m → ℝ)? : @@ -127,23 +130,23 @@ feasibility simp [Vec.klDiv, klDiv] intros i have h := klDiv.feasibility1 (x i) (y i) (hx i) (hy i) - unfold posOrthCone at h + unfold nonnegOrthCone at h simpa [Vec.const] using h) optimality fun i => by unfold Vec.expCone at c1 apply klDiv.optimality (x i) (y i) (t i) (c1 i) - unfold posOrthCone expCone at * + unfold nonnegOrthCone expCone at * simpa [Vec.const] using (c2 i) vconditionElimination (hx : fun i => by unfold Vec.expCone at c1 apply klDiv.vcondElim0 (x i) (y i) (t i) (c1 i) - unfold posOrthCone expCone at * + unfold nonnegOrthCone expCone at * simpa [Vec.const] using (c2 i)) (hy : fun i => by unfold Vec.expCone at c1 apply klDiv.vcondElim1 (x i) (y i) (t i) (c1 i) - unfold posOrthCone expCone at * + unfold nonnegOrthCone expCone at * simpa [Vec.const] using (c2 i)) end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Log.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Log.lean index 51dd301d..e48d27bc 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Log.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Log.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Lib.Math.Data.Vec +/-! +Logarithm atom (concave). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogDet.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogDet.lean index f6c0074e..d68166d0 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogDet.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogDet.lean @@ -1,7 +1,7 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones +import CvxLean.Tactic.DCP.AtomLibrary.Sets.PosSemidef import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sum -import CvxLean.Tactic.DCP.AtomLibrary.Fns.PosSemidef import CvxLean.Tactic.DCP.AtomLibrary.Fns.FromBlocks import CvxLean.Tactic.DCP.AtomLibrary.Fns.Diag import CvxLean.Tactic.DCP.AtomLibrary.Fns.Diagonal @@ -9,6 +9,10 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Transpose import CvxLean.Tactic.DCP.AtomLibrary.Fns.ToUpperTri import CvxLean.Lib.Math.LogDet +/-! +Log-det atom (concave). +-/ + namespace CvxLean open Matrix @@ -17,8 +21,7 @@ declare_atom Matrix.logDet [concave] (n : ℕ)& (A : Matrix.{0,0,0} (Fin n) (Fin n) ℝ)? : Real.log A.det := vconditions (hA : A.PosDef) implementationVars (t : Fin n → ℝ) (Y : Matrix (Fin n) (Fin n) ℝ) --- The lower left values of `Y` are unused. CVXPY uses a vector `z` instead of a --- matrix `Y`. +-- The lower left values of `Y` are unused. CVXPY uses a vector `z` instead of a matrix `Y`. implementationObjective Vec.sum t implementationConstraints (c_exp : Real.Vec.expCone t 1 Y.diag) diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogSumExp.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogSumExp.lean index 1c8c9f4f..738b9fcd 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogSumExp.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/LogSumExp.lean @@ -9,6 +9,10 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecCons import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecConst import CvxLean.Lib.Math.LogSumExp +/-! +Log-sum-exp atom (convex). +-/ + namespace CvxLean open Real @@ -32,7 +36,7 @@ feasibility optimality by intros y hy simp [Vec.sum_exp_eq_sum_div, div_le_iff (exp_pos _)] at c - rw [←log_exp t, log_le_log_iff (Vec.sum_exp_pos hn y) (exp_pos _)] + rw [← log_exp t, log_le_log_iff (Vec.sum_exp_pos hn y) (exp_pos _)] refine le_trans ?_ c apply Finset.sum_le_sum intros i _ diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixConst.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixConst.lean index da27ce6a..726a2fdf 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixConst.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixConst.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Atom for the operation of making a matrix with all entries set to the same value (affine). +-/ + namespace CvxLean declare_atom Matrix.const [affine] (m : Type)& (n : Type)& (k : ℝ)& : diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixMul.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixMul.lean index e46e57e1..5358571d 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixMul.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MatrixMul.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Matrix multiplication atoms (affine). +-/ + namespace CvxLean declare_atom Matrix.mul1 [affine] (m : Type)& (hm : Fintype.{0} m)& diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Max.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Max.lean index 2e98fe93..ada5c07d 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Max.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Max.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Le import CvxLean.Lib.Math.Data.Vec +/-! +Max atom (convex). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MaximizeNeg.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MaximizeNeg.lean index c6b1aebf..1d94f6cf 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MaximizeNeg.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MaximizeNeg.lean @@ -1,5 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +`maximizeNeg` is just negation, really, but needs its own atom as it has its own identifier +(affine). +-/ + namespace CvxLean declare_atom maximizeNeg [affine] (x : ℝ)- : maximizeNeg x := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Min.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Min.lean index a936af13..f726c3bf 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Min.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Min.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Le import CvxLean.Lib.Math.Data.Vec +/-! +Min atom (concave). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Mul.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Mul.lean index d6fea771..1398b9b6 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Mul.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Mul.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Multiplication atoms (affine). +-/ + namespace CvxLean declare_atom mul1 [affine] (x : ℝ)& (y : ℝ)+ : x * y := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MulVec.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MulVec.lean index df342282..abba7512 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/MulVec.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/MulVec.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Atom for matrix-vector multiplication (affine). +-/ + namespace CvxLean declare_atom Matrix.mulVec [affine] (n : ℕ)& (m : ℕ)& diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Neg.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Neg.lean index 79738ae7..9bfeb7c1 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Neg.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Neg.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Atom for negation (affine). +-/ + namespace CvxLean declare_atom neg [affine] (x : ℝ)- : - x := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean index 66f5c6f1..4a634967 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Norm.lean @@ -4,6 +4,8 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecCons import CvxLean.Lib.Math.Data.Vec /-! +Norm atoms (convex). + ### TODO This is not defined in full generality. It can be made increasing or decreasing in each `xᵢ` diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Nth.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Nth.lean index 6bc87753..c631b39e 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Nth.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Nth.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Atoms for vector and matrix accesses (affine). +-/ + namespace CvxLean declare_atom Vec.nth [affine] (m : Nat)& (x : Fin m → ℝ)? (i : Fin m)& : x i := @@ -14,7 +18,8 @@ additivity by ring optimality le_refl _ -declare_atom Matrix.nth [affine] (m : Nat)& (X : Matrix.{0,0,0} (Fin m) (Fin m) ℝ)? (i : Fin m)& : X i := +declare_atom Matrix.nth [affine] (m : Nat)& (X : Matrix.{0,0,0} (Fin m) (Fin m) ℝ)? (i : Fin m)& : + X i := bconditions homogenity by rw [Pi.zero_apply, add_zero, smul_zero, add_zero] @@ -24,7 +29,8 @@ additivity by rfl optimality le_refl _ -declare_atom Matrix.nth2 [affine] (m : Nat)& (X : Matrix.{0,0,0} (Fin m) (Fin m) ℝ)? (i : Fin m)& (j : Fin m)& : X i j := +declare_atom Matrix.nth2 [affine] (m : Nat)& (X : Matrix.{0,0,0} (Fin m) (Fin m) ℝ)? (i : Fin m)& + (j : Fin m)& : X i j := bconditions homogenity by rw [Pi.zero_apply, Pi.zero_apply, smul_zero] diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Perspective.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Perspective.lean deleted file mode 100644 index 96e4bbda..00000000 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Perspective.lean +++ /dev/null @@ -1,89 +0,0 @@ -import CvxLean.Tactic.DCP.AtomCmd -import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones - -namespace CvxLean - --- -- Convex --- class IsRealAtom (f : ℝ → ℝ) where --- (impObj : ℝ × ℝ → ℝ) --- (impObjAffine : ∀ (x : ℝ) (t : ℝ) (s : ℝ), s * impObj ⟨x, t⟩ = impObj ⟨s * x, s * t⟩) --- (impConstr : ℝ × ℝ → Prop) --- (sol : ℝ → ℝ) --- (solEqAtom : ∀ (x : ℝ), impObj ⟨x, sol x⟩ = f x) --- (feasibility : ∀ (x : ℝ), impConstr ⟨x, sol x⟩) --- (optimality : ∀ (x : ℝ) (t : ℝ), impConstr ⟨x, t⟩ → f x ≤ impObj ⟨x, t⟩) - --- declare_atom IsSimpleRealAtom.impConstr [cone] (f : ℝ → ℝ)? (hatom : IsRealAtom f)? (x : ℝ)? (t : ℝ)? (s : ℝ)? : hatom.impConstr ⟨x / s, t / s⟩ := --- optimality le_refl _ - --- declare_atom IsSimpleRealAtom.impObj [convex] (f : ℝ → ℝ)? (hatom : IsRealAtom f)? (x : ℝ)? (t : ℝ)? (s : ℝ)? : hatom.impObj ⟨x, t⟩ := --- vconditions --- implementationVars (z : ℝ) --- implementationObjective sorry --- implementationConstraints --- (c : True) --- solution (z := sorry) --- solutionEqualsAtom sorry --- feasibility --- (c : trivial) --- optimality by --- sorry --- vconditionElimination - --- #check EuclideanDomain.mul_div_cancel_left - --- -- IsConvexReducible? - --- Another option: declare_atom_unsafe, and do everything in meta. - --- declare_atom perspective [convex] (f : ℝ → ℝ)? (x : ℝ)? (hatom : IsRealAtom f)? (s : ℝ)? : s * f (x / s) := --- vconditions (cond : 0 < s) --- implementationVars (t : ℝ) --- implementationObjective (hatom.impObj ⟨x, t⟩) --- implementationConstraints --- (c : hatom.impConstr ⟨x / s, t / s⟩) --- (cs : 0 < s) --- solution (t := s * hatom.sol (x / s)) --- solutionEqualsAtom by { --- simp --- sorry --- } --- feasibility --- (c : by { --- dsimp --- simp [mul_div_cancel_left (a := s) (hatom.sol (x / s)) (ne_of_lt cond).symm] --- exact hatom.feasibility (x / s) --- }) --- (cs : cond) --- optimality by { --- have := hatom.optimality (x / s) (t / s) c --- sorry --- } --- vconditionElimination --- (cond : cs) - --- declare_atom perspective [convex] (f : ℝ → ℝ)? (x : ℝ)? (s : ℝ)? : s * f (x / s) := --- vconditions (cond : 0 < s) --- implementationVars (t : ℝ) --- implementationObjective t + 34 --- implementationConstraints --- (cs : 0 < s) --- solution (t := 0) --- solutionEqualsAtom by --- sorry --- feasibility --- (cs : cond) --- optimality by --- sorry --- vconditionElimination --- (cond : cs) - - --- example : Minimization.Solution <| --- optimization (x : ℝ) --- minimize (2 * Real.exp (x / 2)) --- subject to --- h : Real.posOrthCone x := by --- dcp - -end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Power.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Power.lean index 1d2cbe31..c14b13ce 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Power.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Power.lean @@ -6,6 +6,13 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Add import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sub import CvxLean.Tactic.DCP.AtomLibrary.Fns.Mul +/-! +Some power atoms: +* x ^ 1 (affine). +* x ^ (-1) (convex). +* x ^ (-2) (convex). +-/ + namespace CvxLean open Real @@ -49,7 +56,7 @@ optimality by have hxpos : 0 < x := by cases (lt_or_eq_of_le c3) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith have hypos := lt_of_lt_of_le hxpos hy rw [rpow_neg_one, inv_eq_one_div, div_le_iff hypos] apply le_trans c1 @@ -61,7 +68,7 @@ vconditionElimination rw [soCone_add_sub_two_of_nonneg c2 c3] at c1 cases (lt_or_eq_of_le c3) with | inl h => linarith - | inr h => rw [←h] at c1; linarith) + | inr h => rw [← h] at c1; linarith) -- NOTE(RFM): It is not straightforward to express it in terms of x ^ (-1) and -- x ^ 2 because of vconditions. @@ -97,7 +104,7 @@ feasibility have hxneg2nn := le_of_lt hxneg2 have h1nn : 0 ≤ (1 : ℝ) := by norm_num rw [soCone_add_sub_two_mul_of_nonneg (x ^ (-1)) hxneg2nn h1nn] - rw [←rpow_mul hxnn] + rw [← rpow_mul hxnn] field_simp) (c3 : by dsimp @@ -117,7 +124,7 @@ optimality by have hxpos : 0 < x := by cases (lt_or_eq_of_le c5) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith have hypos := lt_of_lt_of_le hxpos hy have hynn := le_of_lt hypos have hy2pos : 0 < y ^ (2 : ℝ):= by @@ -126,7 +133,7 @@ optimality by have ht₁pos : 0 < t₁ := by cases (lt_or_eq_of_le c4) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith have ht₁2pos : 0 < t₁ ^ (2 : ℝ) := by rw [rpow_two] exact pow_two_pos_of_ne_zero t₁ (ne_of_gt ht₁pos) @@ -134,18 +141,18 @@ optimality by have ht₀pos : 0 < t₀ := by cases (lt_or_eq_of_le c3) with | inl h => exact h - | inr h => rw [←h] at c2; linarith + | inr h => rw [← h] at c2; linarith -- Combine c1 and c2 appropriately to get t₀ ^ (-1) ≤ y ^ 2. have ht₁invx : t₁ ^ (-1 : ℝ) ≤ x := by rw [rpow_neg c4, rpow_one] - rwa [←div_le_iff ht₁pos, ←inv_eq_one_div] at c1 + rwa [← div_le_iff ht₁pos, ← inv_eq_one_div] at c1 have ht₁invy : t₁ ^ (-1 : ℝ) ≤ y := le_trans ht₁invx hy have ht₁invnn := le_of_lt ht₁inv have ht₁neg2y2 : t₁ ^ (-2 : ℝ) ≤ y ^ (2 : ℝ) := by have h := mul_le_mul ht₁invy ht₁invy ht₁invnn hynn - rw [←rpow_add ht₁pos] at h + rw [← rpow_add ht₁pos] at h norm_num at h - rw [←pow_two, ←rpow_two] at h + rw [← pow_two, ← rpow_two] at h exact h have ht₀invt₁neg2 : t₀ ^ (-1 : ℝ) ≤ t₁ ^ (-2 : ℝ) := by rw [mul_one] at c2 @@ -155,8 +162,8 @@ optimality by have ht₀invy2 := le_trans ht₀invt₁neg2 ht₁neg2y2 -- Fit inequality to goal. rw [rpow_neg hynn, inv_eq_one_div, div_le_iff hy2pos] - rw [mul_comm, ←div_le_iff ht₀pos, ←inv_eq_one_div] - rw [←rpow_one t₀, ←rpow_neg c3] + rw [mul_comm, ← div_le_iff ht₀pos, ← inv_eq_one_div] + rw [← rpow_one t₀, ← rpow_neg c3] exact ht₀invy2 vconditionElimination (hx : by @@ -168,7 +175,7 @@ vconditionElimination have hxpos : 0 < x := by cases (lt_or_eq_of_le c5) with | inl h => exact h - | inr h => rw [←h] at c1; linarith + | inr h => rw [← h] at c1; linarith exact lt_of_lt_of_le hxpos hy) end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/QuadOverLin.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/QuadOverLin.lean index 450df790..915d4dd2 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/QuadOverLin.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/QuadOverLin.lean @@ -10,12 +10,16 @@ import CvxLean.Tactic.DCP.AtomLibrary.Fns.Exp import CvxLean.Tactic.DCP.AtomLibrary.Fns.Transpose import CvxLean.Lib.Math.Data.Vec +/-! +Quadratic-over-linear atom (convex). +-/ + namespace CvxLean open Real --- TODO(RFM): generalize to x : Fin n → ℝ --- TODO(RFM): distinguish between nonincreasing and nondecreasing. +-- TODO: generalize to `x : Fin n → ℝ`. +-- TODO: distinguish between nonincreasing and nondecreasing. declare_atom quadOverLin [convex] (x : ℝ)? (y : ℝ)- : x ^ (2 : ℝ) / y := vconditions (hy : 1 / 100000 ≤ y) @@ -34,7 +38,7 @@ feasibility have hx2ynn : 0 ≤ x ^ (2 : ℝ) / y := by rw [rpow_two]; exact div_nonneg (pow_two_nonneg x) hynn rw [soCone_add_sub_two_mul_of_nonneg _ hynn hx2ynn] - rw [mul_div, mul_comm, ←mul_div, div_self (ne_of_gt hypos), mul_one]) + rw [mul_div, mul_comm, ← mul_div, div_self (ne_of_gt hypos), mul_one]) (c2 : by have hynn : 0 ≤ y := by positivity rw [rpow_two] @@ -82,7 +86,7 @@ feasibility intros t i; rw [vec_soCone_apply_to_soCone_add_sub_two_mul] rw [soCone_add_sub_two_mul_of_nonneg _ (hynn i) (hx2ynn i)]; dsimp - rw [mul_div, mul_comm, ←mul_div, div_self (ne_of_gt (hypos i)), mul_one]) + rw [mul_div, mul_comm, ← mul_div, div_self (ne_of_gt (hypos i)), mul_one]) (c2 : by have hynn : 0 ≤ y := fun i => by have hyi := hy i diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/SMul.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/SMul.lean index 970170c0..8a49a31f 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/SMul.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/SMul.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomCmd +/-! +Scalar multiplication atoms (affine). +-/ + namespace CvxLean declare_atom Nat.smul [affine] (n : ℕ)& (x : ℝ)+ : @@ -11,7 +15,7 @@ additivity by rw [smul_zero, add_zero, smul_add] optimality by intros y' hy - apply smul_le_smul_of_nonneg hy (Nat.zero_le _) + apply smul_le_smul_of_nonneg_left hy (Nat.zero_le _) declare_atom Nat.Vec.smul [affine] (n : ℕ)& (k : ℕ)& (x : Fin n → ℝ)+ : k • x := @@ -22,7 +26,7 @@ additivity by rw [smul_zero, add_zero, smul_add] optimality by intros y' hy i - exact smul_le_smul_of_nonneg (hy i) (Nat.zero_le _) + exact smul_le_smul_of_nonneg_left (hy i) (Nat.zero_le _) declare_atom Real.Vec.smul1 [affine] (n : ℕ)& (k : ℝ)& (x : Fin n → ℝ)+ : k • x := @@ -34,7 +38,7 @@ additivity by rw [smul_zero, add_zero, smul_add] optimality by intros y' hy - apply smul_le_smul_of_nonneg hy hk + apply smul_le_smul_of_nonneg_left hy hk declare_atom Real.Vec.smul2 [affine] (n : ℕ)& (k : ℝ)& (x : Fin n → ℝ)- : k • x := @@ -46,7 +50,7 @@ additivity by rw [smul_zero, add_zero, smul_add] optimality by intros y' hy - apply smul_le_smul_of_nonpos hy hk + apply smul_le_smul_of_nonpos_left hy hk declare_atom Real.Vec.smul3 [affine] (n : ℕ)& (k : ℝ)& (x : Fin n → ℝ)? : k • x := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sq.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sq.lean index 4c4be25c..18cf21f2 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sq.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sq.lean @@ -3,6 +3,12 @@ import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecCons import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecToMatrix +/-! +Square atom (convex). + +Usually just considered as part of the power atom. +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sqrt.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sqrt.lean index 939fcb5d..3d146175 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sqrt.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sqrt.lean @@ -2,6 +2,12 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Fns.VecCons +/-! +Square root atom (concave). + +Usually just considered as part of the power atom. +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sub.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sub.lean index 23a00b75..175b9ef5 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sub.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sub.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Subtraction atom (affine). +-/ + namespace CvxLean declare_atom sub [affine] (x : ℝ)+ (y : ℝ)- : x - y := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean index e2c2dd99..9f62a577 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Sum.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Vec import CvxLean.Lib.Math.Data.Matrix +/-! +Sum atom (affine). +-/ + namespace CvxLean open Vec in diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/ToUpperTri.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/ToUpperTri.lean index bb5cfa0f..339781c4 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/ToUpperTri.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/ToUpperTri.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix import CvxLean.Lib.Math.LinearAlgebra.Matrix.ToUpperTri +/-! +Atom for the operation that returns the upper triangular part of a matrix (affine). +-/ + namespace CvxLean declare_atom Matrix.toUpperTri [affine] (n : ℕ)& (A : Matrix.{0,0,0} (Fin n) (Fin n) ℝ)+ : @@ -14,6 +18,6 @@ additivity by rw [Matrix.toUpperTri_zero, add_zero, Matrix.toUpperTri_add] optimality by intros A' hA i j - by_cases h : i ≤ j <;> simp [Matrix.toUpperTri, h] ; exact hA i j + by_cases h : i ≤ j <;> simp [Matrix.toUpperTri, h]; exact hA i j end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Trace.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Trace.lean index b8ea6dc6..18633cdc 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Trace.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Trace.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Trace atom (affine). +-/ + namespace CvxLean declare_atom Matrix.trace [affine] (m : Type)& (hm : Fintype.{0} m)& diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Transpose.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Transpose.lean index 38c88f81..80a7771c 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/Transpose.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/Transpose.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Transpose atom (affine). +-/ + namespace CvxLean declare_atom Matrix.transpose [affine] (n : Type)& (m : Type)& (A : Matrix.{0,0,0} n m ℝ)+ : diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecCons.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecCons.lean index e045b27a..5082f420 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecCons.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecCons.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Matrix +/-! +Vector and matrix cons-ing atoms (affine). +-/ + namespace CvxLean open Matrix @@ -15,8 +19,8 @@ additivity by optimality by intros x' y' hx hy i refine' Fin.cases _ _ i <;> simp [vecCons] - . exact hx - . exact hy + · exact hx + · exact hy declare_atom Matrix.vecCons [affine] (n : Nat)& (m : Nat)& (x : Fin n → ℝ)+ (y : (Fin m) → (Fin n) → ℝ)+ : vecCons x y := @@ -28,7 +32,7 @@ additivity by optimality by intros x' y' hx hy i refine' Fin.cases _ _ i <;> simp [vecCons] - . exact hx - . exact hy + · exact hx + · exact hy end CvxLean diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecConst.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecConst.lean index 5eda8ac8..77e0697d 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecConst.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecConst.lean @@ -1,6 +1,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Vec +/-! +Atom for making a constant vector (affine). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecToMatrix.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecToMatrix.lean index 5a1dc6da..69a38d7a 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecToMatrix.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/VecToMatrix.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Math.Data.Vec import CvxLean.Lib.Math.Data.Matrix +/-! +Atom for converting a vector to a matrix (affine). +-/ + namespace CvxLean open Real diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/XExp.lean b/CvxLean/Tactic/DCP/AtomLibrary/Fns/XExp.lean index f7a317b7..9a71b4f6 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/XExp.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Fns/XExp.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Sets.Le import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sq +/-! +Atom for `x * exp(x)` (convex). +-/ + namespace CvxLean open Real @@ -20,15 +24,15 @@ feasibility (c1 : by unfold expCone by_cases h : 0 < x - . left + · left refine ⟨h, ?_⟩ - rw [rpow_two, pow_two, ←mul_div, div_self (ne_of_lt h).symm, mul_one] - . right + rw [rpow_two, pow_two, ← mul_div, div_self (ne_of_lt h).symm, mul_one] + · right replace h := le_of_not_lt h have hxeq0 := le_antisymm h cond simp [hxeq0]) (c2 : by norm_num) - (c3 : by simp [posOrthCone, cond]) + (c3 : by simp [nonnegOrthCone, cond]) optimality by { unfold expCone at c1 simp at c2 diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Sets/All.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/All.lean index 962a6a14..90e35421 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Sets/All.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/All.lean @@ -1,3 +1,4 @@ import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Sets.Eq import CvxLean.Tactic.DCP.AtomLibrary.Sets.Le +import CvxLean.Tactic.DCP.AtomLibrary.Sets.PosSemidef diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Cones.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Cones.lean index 0a147e37..1d6202c4 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Cones.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Cones.lean @@ -2,6 +2,10 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Lib.Cones.All import CvxLean.Syntax.Minimization +/-! +Cone atoms. They are treated as convex sets with the implementation being themselves. +-/ + namespace CvxLean open Real @@ -15,25 +19,25 @@ optimality by end ZeroCone -/- Positive orthant cone atoms. -/ -section PosOrthCone +/- Nonnegative orthant cone atoms. -/ +section NonnegOrthCone -declare_atom posOrthCone [cone] (x : ℝ)? : posOrthCone x := +declare_atom nonnegOrthCone [cone] (x : ℝ)? : nonnegOrthCone x := optimality by - simp [posOrthCone] + simp [nonnegOrthCone] -declare_atom Vec.posOrthCone [cone] - (n : Nat)& (x : (Fin n) → ℝ)? : Vec.posOrthCone x := +declare_atom Vec.nonnegOrthCone [cone] + (n : Nat)& (x : (Fin n) → ℝ)? : Vec.nonnegOrthCone x := optimality by - simp [Vec.posOrthCone] + simp [Vec.nonnegOrthCone] -declare_atom Matrix.posOrthCone [cone] +declare_atom Matrix.nonnegOrthCone [cone] (m : Nat)& (n : Nat)& (M : Matrix.{0,0,0} (Fin m) (Fin n) ℝ)? : - Real.Matrix.posOrthCone M := + Real.Matrix.nonnegOrthCone M := optimality by - simp [Matrix.posOrthCone] + simp [Matrix.nonnegOrthCone] -end PosOrthCone +end NonnegOrthCone /- Exponential cone atoms. -/ section ExpCone @@ -77,7 +81,7 @@ optimality le_refl _ end PSDCone -/- Trivial cone. -/ +/- Trivial cone atom. -/ section Trivial declare_atom trivial [cone] : True := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Eq.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Eq.lean index 27a0dcef..912a864b 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Eq.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Eq.lean @@ -2,6 +2,12 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sub +/-! +Equality atom (convex set). + +Note that we do not distinguish between affine and convex sets. +-/ + namespace CvxLean declare_atom eq [convex_set] (x : ℝ)? (y : ℝ)? : x = y := diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean index d9231129..82d9946b 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/Le.lean @@ -2,21 +2,27 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones import CvxLean.Tactic.DCP.AtomLibrary.Fns.Sub +/-! +Inequality atoms (convex set). + +Note that we do not distinguish between affine and convex sets. +-/ + namespace CvxLean declare_atom le [convex_set] (x : ℝ)- (y : ℝ)+ : x ≤ y := vconditions implementationVars -implementationObjective Real.posOrthCone (y - x) +implementationObjective Real.nonnegOrthCone (y - x) implementationConstraints solution solutionEqualsAtom by - unfold Real.posOrthCone + unfold Real.nonnegOrthCone simp feasibility optimality by intros x' y' hx hy h - unfold Real.posOrthCone at h + unfold Real.nonnegOrthCone at h simp at h exact (hx.trans h).trans hy vconditionElimination @@ -24,24 +30,24 @@ vconditionElimination declare_atom Vec.le [convex_set] (n : Nat)& (x : (Fin n) → ℝ)- (y : (Fin n) → ℝ)+ : x ≤ y := vconditions implementationVars -implementationObjective Real.Vec.posOrthCone (y - x : (Fin n) → ℝ) +implementationObjective Real.Vec.nonnegOrthCone (y - x : (Fin n) → ℝ) implementationConstraints solution solutionEqualsAtom by - unfold Real.Vec.posOrthCone - unfold Real.posOrthCone + unfold Real.Vec.nonnegOrthCone + unfold Real.nonnegOrthCone rw [← iff_iff_eq] constructor · intros h i rw [← le.solEqAtom] - unfold Real.posOrthCone + unfold Real.nonnegOrthCone apply h · intros hy i simp [hy i] feasibility optimality by intros x' y' hx hy h i - unfold Real.Vec.posOrthCone at h + unfold Real.Vec.nonnegOrthCone at h exact le.optimality _ _ _ _ (hx i) (hy i) (h i) vconditionElimination @@ -49,24 +55,24 @@ declare_atom Vec.le' [convex_set] (n : Nat)& (x : (Fin n) → ℝ)- (y : (Fin n) ∀ i, x i ≤ y i := vconditions implementationVars -implementationObjective Real.Vec.posOrthCone (y - x : (Fin n) → ℝ) +implementationObjective Real.Vec.nonnegOrthCone (y - x : (Fin n) → ℝ) implementationConstraints solution solutionEqualsAtom by - unfold Real.Vec.posOrthCone - unfold Real.posOrthCone + unfold Real.Vec.nonnegOrthCone + unfold Real.nonnegOrthCone rw [← iff_iff_eq] constructor · intros h i rw [← le.solEqAtom] - unfold Real.posOrthCone + unfold Real.nonnegOrthCone apply h · intros hy i simp [hy i] feasibility optimality by intros x' y' hx hy h i - unfold Real.Vec.posOrthCone at h + unfold Real.Vec.nonnegOrthCone at h exact le.optimality _ _ _ _ (hx i) (hy i) (h i) vconditionElimination diff --git a/CvxLean/Tactic/DCP/AtomLibrary/Fns/PosSemidef.lean b/CvxLean/Tactic/DCP/AtomLibrary/Sets/PosSemidef.lean similarity index 57% rename from CvxLean/Tactic/DCP/AtomLibrary/Fns/PosSemidef.lean rename to CvxLean/Tactic/DCP/AtomLibrary/Sets/PosSemidef.lean index 55cb3e97..453b0c25 100644 --- a/CvxLean/Tactic/DCP/AtomLibrary/Fns/PosSemidef.lean +++ b/CvxLean/Tactic/DCP/AtomLibrary/Sets/PosSemidef.lean @@ -1,11 +1,16 @@ import CvxLean.Tactic.DCP.AtomCmd import CvxLean.Tactic.DCP.AtomLibrary.Sets.Cones +/-! +Atom for positive semidefinite constraints (convex set). + +This can be seen as an alias for `Matrix.PSDCone`, really. +-/ + namespace CvxLean -declare_atom Matrix.PosSemidef [convex_set] - (m : Type)& (hm : Fintype.{0} m)& (A : Matrix.{0,0,0} m m ℝ)? - : Matrix.PosSemidef A := +declare_atom Matrix.PosSemidef [convex_set] (m : Type)& (hm : Fintype.{0} m)& + (A : Matrix.{0,0,0} m m ℝ)? : Matrix.PosSemidef A := vconditions implementationVars implementationObjective Real.Matrix.PSDCone A diff --git a/CvxLean/Tactic/DCP/AtomSyntax.lean b/CvxLean/Tactic/DCP/AtomSyntax.lean index 01997b20..34881570 100644 --- a/CvxLean/Tactic/DCP/AtomSyntax.lean +++ b/CvxLean/Tactic/DCP/AtomSyntax.lean @@ -1,49 +1,62 @@ import Lean +/-! +# Atom declaration syntax + +This file defines the `declare_atom` syntax. This includes the syntax for specifying the atom's +signature: name, curvature, and arguments with their kind (+, -, ?, &), and the atom's expression. +It also defines the syntax for the rest of the fields needed in an atom declaration. There are +different versions of the command, some of which do not require all fields. +-/ + namespace CvxLean namespace Parser -open Lean.Parser +open Lean Parser -def implementationObjective := leading_parser withPosition $ - nonReservedSymbol "implementationObjective" >> +def implementationObjective := leading_parser withPosition <| + nonReservedSymbol "implementationObjective" >> checkColGt "indented term" >> termParser -def solutionEqualsAtom := leading_parser withPosition $ - nonReservedSymbol "solutionEqualsAtom" >> +def solutionEqualsAtom := leading_parser withPosition <| + nonReservedSymbol "solutionEqualsAtom" >> checkColGt "indented term" >> termParser -def optimality := leading_parser withPosition $ - nonReservedSymbol "optimality" >> +def optimality := leading_parser withPosition <| + nonReservedSymbol "optimality" >> checkColGt "indented term" >> termParser -def homogenity := leading_parser withPosition $ - nonReservedSymbol "homogenity" >> +def homogenity := leading_parser withPosition <| + nonReservedSymbol "homogenity" >> checkColGt "indented term" >> termParser -def additivity := leading_parser withPosition $ - nonReservedSymbol "additivity" >> +def additivity := leading_parser withPosition <| + nonReservedSymbol "additivity" >> checkColGt "indented term" >> termParser end Parser declare_syntax_cat id_with_type + syntax "(" ident ":" term ")" : id_with_type declare_syntax_cat id_with_def + syntax "(" ident ":=" term ")" : id_with_def declare_syntax_cat arg_kind + syntax "+" : arg_kind syntax "-" : arg_kind syntax "?" : arg_kind syntax "&" : arg_kind declare_syntax_cat arg_with_kind + syntax "(" ident ":" term ")" arg_kind : arg_with_kind -syntax (name:=atomCommand) +syntax (name := atomCommand) "declare_atom" ident "[" ident "]" arg_with_kind* ":" term ":=" &"vconditions" id_with_type* &"implementationVars" id_with_type* @@ -55,7 +68,7 @@ syntax (name:=atomCommand) Parser.optimality &"vconditionElimination" id_with_type* : command -syntax (name:=atomWithBCondsCommand) +syntax (name := atomWithBCondsCommand) "declare_atom" ident "[" ident "]" arg_with_kind* ":" term ":=" &"bconditions" id_with_type* &"vconditions" id_with_type* @@ -68,18 +81,15 @@ syntax (name:=atomWithBCondsCommand) Parser.optimality &"vconditionElimination" id_with_type* : command -syntax (name:=affineAtomCommand) +syntax (name := affineAtomCommand) "declare_atom" ident "[" "affine" "]" arg_with_kind* ":" term ":=" &"bconditions" id_with_type* Parser.homogenity - Parser.additivity + Parser.additivity Parser.optimality : command -syntax (name:=coneAtomCommand) +syntax (name := coneAtomCommand) "declare_atom" ident "[" "cone" "]" arg_with_kind* ":" term ":=" Parser.optimality : command end CvxLean - - - diff --git a/CvxLean/Tactic/DCP/DCP.lean b/CvxLean/Tactic/DCP/DCP.lean index 99561821..522fbc22 100644 --- a/CvxLean/Tactic/DCP/DCP.lean +++ b/CvxLean/Tactic/DCP/DCP.lean @@ -6,569 +6,42 @@ import CvxLean.Meta.TacticBuilder import CvxLean.Lib.Math.Data.Array import CvxLean.Lib.Minimization import CvxLean.Lib.Equivalence -import CvxLean.Tactic.DCP.Tree -import CvxLean.Tactic.DCP.OC -import CvxLean.Tactic.DCP.AtomExt +import CvxLean.Tactic.DCP.DCPTypes +import CvxLean.Tactic.DCP.DCPMakers import CvxLean.Tactic.Arith.Arith +/-! +# Main DCP algorithm + +This file contains the main DCP algorithm (`canonize`), it can also be used as a tactic (`dcp`). +It uses the processed atom tree from `mkProcessedAtomTree` in `CvxLean/Tactic/DCP/DCPMakers.lean` to +build the forward and backward maps and prove the four properties of a `StrongEquivalence`. +-/ + namespace CvxLean open Lean Expr Meta namespace DCP -abbrev GraphAtomDataTree := Tree GraphAtomData Expr - -abbrev Argument := Expr -abbrev Arguments := Array Argument -abbrev ArgumentsTree := Tree Arguments Unit - -abbrev CurvatureTree := Tree Curvature Curvature - -abbrev BCond := Expr -abbrev BConds := Array BCond -abbrev BCondsTree := Tree BConds BConds - -abbrev AtomDataTrees := GraphAtomDataTree × ArgumentsTree × CurvatureTree × BCondsTree - -abbrev NewVarsTree := Tree (Array LocalDecl) Unit -abbrev NewConstrVarsTree := Tree (Array LocalDecl) Unit - -/-- Return list of all registered atoms that match with a given expression. -For every registered atom, it returns: -* The list of arguments as Lean expressions for further matching. -* The atom entry, of type `GraphAtomData`. -/ -def findRegisteredAtoms (e : Expr) : MetaM (Array (Arguments × GraphAtomData)) := do - let discrTree ← getAtomDiscrTree - let atoms ← discrTree.getMatch (← zetaReduce (← e.removeMData)) - trace[Meta.debug] "discrTree {atoms.size} {e}" - let mut goodAtoms := #[] - for atom in atoms do - let (mvars, _, pattern) ← lambdaMetaTelescope atom.expr - if ← isDefEq pattern e then - let args ← mvars.mapM instantiateMVars - goodAtoms := goodAtoms.push (args, atom.graph!) - else - trace[Meta.debug] "Pattern did not match. (Pattern {toString pattern}; Expression {toString e})" - -- Heuristic sorting of potential atoms to use: larger expressions have priority. - goodAtoms := goodAtoms.insertionSort (fun a b => (a.2.expr.size - b.2.expr.size != 0)) - return goodAtoms - -/-- Data type used to indicate whether an atom could be successfully -processed. It is used by `processAtom`. A successful match includes 4 trees: -* A tree of atom data (see `GraphAtomData`). -* A tree of arguments for every node. -* A tree of curvatures. -* A tree of background conditions. --/ -inductive FindAtomResult - | Success (res : AtomDataTrees) - | Error (msgs : Array MessageData) - -/-- Given an expression `e`, optimization variables `vars` and the expected -curvature `curvature`, this function attempts to recursively match `e` with -atoms for the library outputing all the necessary information as explained -in `FindAtomResult`. The boolean in the output indicates whether a match from -the library was used. -/ -partial def findAtoms (e : Expr) (vars : Array FVarId) (curvature : Curvature) : - MetaM ( - Bool × - Array MessageData × - AtomDataTrees) := do - let e := eta e - -- Variable case. - if e.isFVar ∧ vars.contains e.fvarId! then - trace[Meta.debug] "Variable {e}" - return (false, #[], Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) - let potentialAtoms ← findRegisteredAtoms e - -- Constant case. - let isConstantExpr := e.isRelativelyConstant vars || curvature == Curvature.Constant - let isPropExpr := (← inferType e).isProp - trace[Meta.debug] "Constant? {e} {vars.map (mkFVar ·)} {isConstantExpr}" - if isConstantExpr ∧ !isPropExpr then - return (false, #[], Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) - let mut failedAtoms : Array MessageData := #[] - if potentialAtoms.size == 0 then - failedAtoms := failedAtoms.push m!"No atom found for {e}" - for (args, atom) in potentialAtoms do - match ← processAtom e vars curvature atom args with - | FindAtomResult.Success (atoms, args, curvatures, bconds) => - trace[Meta.debug] "Success {e} : {atom}" - return (false, failedAtoms, atoms, args, curvatures, bconds) - | FindAtomResult.Error msg => - failedAtoms := failedAtoms ++ msg - return (true, failedAtoms, Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) -where - processAtom (e : Expr) (vars : Array FVarId) (curvature : Curvature) (atom : GraphAtomData) (args : Array Expr) : MetaM FindAtomResult := do - trace[Meta.debug] "Processing atom {atom.expr} for expression {e}, curvature {curvature}" - -- TODO: is this correct? - if atom.curvature != Curvature.Affine ∧ curvature != atom.curvature then - return FindAtomResult.Error - #[m!"Trying atom {atom.expr} for expression {e}: " ++ - m!"Expected {curvature}, but atom is {atom.curvature}"] - let mut abort := false -- TODO: use exception instead? - let mut bconds := #[] - -- withExistingLocalDecls originalConstrVars.toList do - for (bcondName, bcondType) in atom.bconds do - let bcondType := mkAppNBeta bcondType args - let fvarId? ← (← getLCtx).decls.findSomeM? fun decl => match decl with - | none => pure none - | some decl => do - if ← isDefEq decl.type bcondType then return some decl.fvarId else return none - match fvarId? with - | some fvarId => - bconds := bconds.push (mkFVar fvarId) - | none => - -- Try to prove simple bconditions by arith. - let (e, _) ← Lean.Elab.Term.TermElabM.run $ Lean.Elab.Term.commitIfNoErrors? $ do - let v ← Lean.Elab.Term.elabTerm (← `(by arith)).raw (some bcondType) - Lean.Elab.Term.synthesizeSyntheticMVarsNoPostponing - let v ← instantiateMVars v - return v - if let some e' := e then - bconds := bconds.push e' - else - let declsTypes := (← getLCtx).decls.toList.filterMap (Option.map LocalDecl.type) - return FindAtomResult.Error - #[m!"Trying atom {atom.expr} for expression {e}: " ++ - m!"Background Condition {bcondType} not found. (Local context: {declsTypes})"] - - let mut childTrees := #[] - let mut childArgsTrees := #[] - let mut childCurvatures := #[] - let mut childBConds := #[] - for i in [:args.size] do - let arg := args[i]! - if atom.argKinds[i]! == ArgKind.Constant ∧ not (arg.isRelativelyConstant vars) then - return FindAtomResult.Error - #[m!"Trying atom {atom.expr} for expression {e}: " ++ - m!"Expected constant argument, but found: {arg}"] - let c := curvatureInArg curvature atom.argKinds[i]! - trace[Meta.debug] "Trying to find atoms for {arg} with curvature {c}" - trace[Meta.debug] "vars: {vars.map (mkFVar ·)}" - trace[Meta.debug] "args: {args}" - let (childFailed, childFailedAtoms, childTree, childArgsTree, childCurvature, childBCond) ← findAtoms arg vars c - if childFailed then - return FindAtomResult.Error childFailedAtoms - trace[Meta.debug] "Found {childTree}" - childTrees := childTrees.push childTree - childArgsTrees := childArgsTrees.push childArgsTree - childCurvatures := childCurvatures.push childCurvature - childBConds := childBConds.push childBCond - if curvature == Curvature.Affine then - trace[Meta.debug] "Affine branch for {e}" - -- For affine branches, we only check that all atoms are affine, but don't store the results. - return FindAtomResult.Success (Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf bconds) - else - return FindAtomResult.Success (Tree.node atom childTrees, Tree.node args childArgsTrees, Tree.node curvature childCurvatures, Tree.node bconds childBConds) - -/-- Returns number of added variables -/ -partial def mkNewVars (i : Nat) : GraphAtomDataTree → ArgumentsTree → MetaM (NewVarsTree × Nat) -| Tree.node atom childAtoms, Tree.node args childArgs => do - let mut childNewVars := #[] - let mut i := i - for k in [:childAtoms.size] do - let (childNewVar, i') ← mkNewVars i childAtoms[k]! childArgs[k]! - i := i' - childNewVars := childNewVars.push $ childNewVar - let mut newVars : Array LocalDecl := #[] - for (n, ty) in atom.impVars do - trace[Meta.debug] "new var type: {ty}" - newVars := newVars.push $ - LocalDecl.cdecl 0 (← mkFreshFVarId) (Name.mkNum n i) (mkAppNBeta ty args) Lean.BinderInfo.default LocalDeclKind.default - i := i + 1 - return (Tree.node newVars childNewVars, i) -| Tree.leaf _, Tree.leaf _ => pure (Tree.leaf (), i) -| _, _ => throwError "Tree mismatch" - -/-- -/ -def mkNewVarDeclList (newVars : OC NewVarsTree) : - MetaM (List LocalDecl) := do - let newVarTrees := #[newVars.objFun] ++ newVars.constr - let mut newVarDecls := #[] - for t in newVarTrees do - newVarDecls := t.fold newVarDecls Array.append - return newVarDecls.toList - -abbrev CanonExpr := Expr -abbrev CanonExprsTree := Tree CanonExpr CanonExpr - -/-- -/ -partial def mkCanonExprs : GraphAtomDataTree → NewVarsTree → MetaM (Tree Expr Expr) - | Tree.node atom childAtoms, Tree.node newVars childNewVars => do - let mut childCanonExprs := #[] - for i in [:childAtoms.size] do - childCanonExprs := childCanonExprs.push $ ← mkCanonExprs childAtoms[i]! childNewVars[i]! - let canonExpr := atom.impObjFun - let canonExpr := mkAppNBeta canonExpr (childCanonExprs.map (·.val)) - let canonExpr := mkAppNBeta canonExpr (newVars.map (mkFVar ·.fvarId)) - return Tree.node canonExpr childCanonExprs - | Tree.leaf e, Tree.leaf () => pure $ Tree.leaf e - | _, _ => throwError "Tree mismatch" - -abbrev NewConstrsTree := Tree (Array Expr) Unit - -abbrev PreVCond := ℕ ⊕ Expr -abbrev PreVConds := Array PreVCond -abbrev PreVCondsTree := Tree PreVConds Unit - -abbrev VCond := Expr -abbrev VConds := Array VCond -abbrev VCondsTree := Tree VConds Unit - -/-- -/ -partial def mkNewConstrs : GraphAtomDataTree → BCondsTree → NewVarsTree → CanonExprsTree → MetaM NewConstrsTree - | Tree.node atom childAtoms, Tree.node bconds childBConds, Tree.node newVars childNewVars, Tree.node canonExprs childCanonExprs => do - let mut childNewConstrs := #[] - for i in [:childAtoms.size] do - childNewConstrs := childNewConstrs.push <| ← mkNewConstrs childAtoms[i]! childBConds[i]! childNewVars[i]! childCanonExprs[i]! - let newConstrs := atom.impConstrs - let newConstrs := newConstrs.map (mkAppNBeta · (childCanonExprs.map Tree.val)) - let newConstrs := newConstrs.map (mkAppNBeta · bconds) - let newConstrs := newConstrs.map (mkAppNBeta · (newVars.map (mkFVar ·.fvarId))) - return Tree.node newConstrs childNewConstrs - | Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => pure $ Tree.leaf () - | _, _, _, _ => throwError "Tree mismatch" - -/-- Returns index of constraint if vcondition corresponds exactly to a constraint -(this is needed for condition elimination). If not, it tries to deduce the condition -from other constraints and returns an expression. -/ -partial def findVConditions (originalConstrVars : Array LocalDecl) (constraints : Array Expr) : - GraphAtomDataTree → ArgumentsTree → MetaM PreVCondsTree - | Tree.node atom childAtoms, Tree.node args childArgs => do - let mut childrenVCondData := #[] - for i in [:childAtoms.size] do - let childVCondData ← - findVConditions originalConstrVars constraints childAtoms[i]! childArgs[i]! - childrenVCondData := childrenVCondData.push childVCondData - let mut vcondData := #[] - for (n, vcond) in atom.vconds do - let vcond := mkAppNBeta vcond args - - -- First, try to see if it matches exactly with any of the constraints. - match ← constraints.findIdxM? (isDefEq vcond) with - | some i => do vcondData := vcondData.push (Sum.inl i) - | none => - -- Infer vconditions from constraints. - vcondData := ← - withExistingLocalDecls originalConstrVars.toList do - let (e, _) ← Lean.Elab.Term.TermElabM.run <| Lean.Elab.Term.commitIfNoErrors? <| do - let tac ← `(by arith) - let v ← Lean.Elab.Term.elabTerm tac.raw (some vcond) - Lean.Elab.Term.synthesizeSyntheticMVarsNoPostponing - instantiateMVars v - - if let some e' := e then - return vcondData.push (Sum.inr e') - else - throwError - "Variable condition for {atom.id} not found or inferred: \n {vcond} {constraints}." - - return (Tree.node vcondData childrenVCondData) - | Tree.leaf _, Tree.leaf _ => pure (Tree.leaf ()) - | _, _ => throwError "Tree mismatch." - --- One variable -abbrev NewVarAssignment := Expr - -abbrev Solution := Array NewVarAssignment - -structure CanonExprWithSolution where - canonExpr : CanonExpr - solution : Solution - -instance : Inhabited CanonExprWithSolution := ⟨⟨default, default⟩⟩ - -abbrev CanonExprsWithSolutionTree := Tree CanonExprWithSolution CanonExprWithSolution - -/-- Returns the canon expression and an array of forward images of new vars -/ -partial def mkCanonWithSolution : GraphAtomDataTree → MetaM CanonExprsWithSolutionTree - | Tree.node atom childAtoms => do - let mut childCanonExprs := #[] - for i in [:childAtoms.size] do - childCanonExprs := childCanonExprs.push $ ← mkCanonWithSolution childAtoms[i]! - let sols := (atom.solution.map (mkAppNBeta · (childCanonExprs.1.map (·.val.1)).toArray)) - let canonExpr := atom.impObjFun - let canonExpr := mkAppNBeta canonExpr (childCanonExprs.1.map (·.val.1)).toArray - let canonExpr := mkAppNBeta canonExpr sols - return Tree.node ⟨canonExpr, sols⟩ childCanonExprs - | Tree.leaf e => - return Tree.leaf ⟨e, #[]⟩ - -/-- Combine all the solutions introduced by every atom expansion. A solution -tells us how to assign new variables from old ones. Here, we collect all such -assignments, which will be used later on to build the forward map between the -original and canon problems. -/ -def mkForwardImagesNewVars (canonWithSolution : OC CanonExprsWithSolutionTree) : MetaM Solution := do - let canonWithSolutionTrees := #[canonWithSolution.objFun] ++ canonWithSolution.constr - let mut fm := #[] - for t in canonWithSolutionTrees do - fm := t.fold fm (fun acc a => Array.append acc a.solution) - trace[Meta.debug] "forwardImagesNewVars {fm}" - return fm - -abbrev SolEqAtomProof := Expr -abbrev SolEqAtomProofsTree := Tree SolEqAtomProof SolEqAtomProof - -/-- Proofs .-/ -partial def mkSolEqAtom : - GraphAtomDataTree → - CanonExprsWithSolutionTree → - VCondsTree → - MetaM SolEqAtomProofsTree - | Tree.node atom childAtoms, - Tree.node canonWithSolution childCanonWithSolution, - Tree.node vcondVars childVCondVars => do - -- Recursive calls for arguments. - let mut childSolEqAtom := #[] - for i in [:childAtoms.size] do - childSolEqAtom := childSolEqAtom.push <| - ← mkSolEqAtom childAtoms[i]! childCanonWithSolution[i]! childVCondVars[i]! - -- Rewrite arguments in atom expr. - let mut solEqAtomR ← mkEqRefl atom.expr - for c in childSolEqAtom do - solEqAtomR ← mkCongr solEqAtomR c.val - -- Use solEqAtom of children to rewrite the arguments in the vconditions. - let mut vconds := #[] - for i in [:vcondVars.size] do - let mut vcondEqCanonVCond ← mkEqRefl atom.vconds[i]!.2 - for c in childSolEqAtom do - vcondEqCanonVCond ← mkCongr vcondEqCanonVCond c.val - vconds := vconds.push $ ← mkEqMPR vcondEqCanonVCond vcondVars[i]! - -- Apply solEqAtom property of the atom. - let solEqAtomL := atom.solEqAtom - let solEqAtomL := mkAppN solEqAtomL (childCanonWithSolution.map (·.val.1)) - let solEqAtomL := mkAppN solEqAtomL vconds - let solEqAtom ← mkEqTrans solEqAtomL solEqAtomR - return Tree.node solEqAtom childSolEqAtom - | Tree.leaf e, Tree.leaf _, Tree.leaf _ => do - return Tree.leaf (← mkEqRefl e) - | _, _, _ => - throwError "Tree mismatch" - -abbrev FeasibilityProof := Expr -abbrev FeasibilityProofs := Array FeasibilityProof -abbrev FeasibilityProofsTree := Tree FeasibilityProofs Unit - -/-- -/ -partial def mkFeasibility : - GraphAtomDataTree → - CanonExprsWithSolutionTree → - BCondsTree → - VCondsTree → - SolEqAtomProofsTree → - MetaM FeasibilityProofsTree - | Tree.node atom childAtoms, - Tree.node _canonWithSolution childCanonWithSolution, - Tree.node bconds childBConds, - Tree.node vcondVars childVCondVars, - Tree.node _solEqAtom childSolEqAtom => do - -- Recursive calls for arguments. - let mut childFeasibility := #[] - for i in [:childAtoms.size] do - let c ← mkFeasibility childAtoms[i]! childCanonWithSolution[i]! childBConds[i]! childVCondVars[i]! childSolEqAtom[i]! - childFeasibility := childFeasibility.push c - -- Use solEqAtom of children to rewrite the arguments in the vconditions. - let mut vconds := #[] - for i in [:vcondVars.size] do - let mut vcondEqCanonVCond ← mkEqRefl atom.vconds[i]!.2 - for c in childSolEqAtom do - vcondEqCanonVCond ← mkCongr vcondEqCanonVCond c.val - vconds := vconds.push $ ← mkEqMPR vcondEqCanonVCond vcondVars[i]! - -- Apply feasibility property of the atom. - let feasibility := atom.feasibility - let feasibility := feasibility.map (mkAppN · (childCanonWithSolution.map (·.val.1))) - let feasibility := feasibility.map (mkAppN · bconds) - let feasibility := feasibility.map (mkAppN · vconds) - let _ ← feasibility.mapM check - return Tree.node feasibility childFeasibility - | Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => - return Tree.leaf () - | _, _, _, _, _ => - throwError "Tree mismatch" - -abbrev OptimalityProof := Expr - -abbrev VCondElimProof := Expr -abbrev VCondElimProofs := Array VCondElimProof - -abbrev OptimalityAndVCondElimProofs := OptimalityProof × VCondElimProofs -abbrev OptimalityAndVCondElimProofsTree := Tree OptimalityAndVCondElimProofs OptimalityAndVCondElimProofs - --- TODO: What does this actually return? - --- What we call optimality is really the building blocks of --- the backward map. --- This is in fact the optimality property but with arguments applied to it. - --- so they look like --- cone cond => inequality --- or just --- inequality (usually when it depends on new vars). --- vcond elimination - -/-- -/ -partial def mkOptimalityAndVCondElim : - GraphAtomDataTree → - ArgumentsTree → - CanonExprsTree → - NewVarsTree → - NewConstrVarsTree → - CurvatureTree → - BCondsTree → - MetaM OptimalityAndVCondElimProofsTree - | Tree.node atom childAtoms, - Tree.node args childArgs, - Tree.node _canonExpr childCanonExpr, - Tree.node newVars childNewVars, - Tree.node newConstrVars childNewConstrVars, - Tree.node curvature childCurvature, - Tree.node bconds childBConds => do - -- Recursive calls for arguments. - let mut childOptimality := #[] - let mut childOptimalityFiltered := #[] - for i in [:childAtoms.size] do - if atom.argKinds[i]! != ArgKind.Constant ∧ atom.argKinds[i]! != ArgKind.Neither then - let opt ← mkOptimalityAndVCondElim childAtoms[i]! childArgs[i]! childCanonExpr[i]! - childNewVars[i]! childNewConstrVars[i]! childCurvature[i]! childBConds[i]! - childOptimality := childOptimality.push opt - childOptimalityFiltered := childOptimalityFiltered.push opt - else - childOptimality := childOptimality.push $ Tree.leaf (mkStrLit s!"dummy {args[i]!}", #[]) - let mut monoArgs := #[] - for i in [:args.size] do - if atom.argKinds[i]! != ArgKind.Constant ∧ atom.argKinds[i]! != ArgKind.Neither then - monoArgs := monoArgs.push args[i]! - - -- Apply optimality property of atom. TODO: why these and.left/right? - let optimality ← - match curvature, atom.curvature with - | Curvature.Concave, Curvature.Affine => mkAppM ``And.left #[atom.optimality] - | Curvature.Convex, Curvature.Affine => mkAppM ``And.right #[atom.optimality] - | _, _ => pure atom.optimality - let optimality := mkAppN optimality (childCanonExpr.map Tree.val) - let optimality := mkAppN optimality bconds - check optimality - let optimality := mkAppN optimality (newVars.map (mkFVar ·.fvarId)) - let optimality := mkAppN optimality (newConstrVars.map (mkFVar ·.fvarId)) - let optimality := mkAppN optimality monoArgs - let optimality := mkAppN optimality (childOptimalityFiltered.map (·.val.1)) - check optimality - - -- Apply vcond elim property of atom. - let vcondElim := atom.vcondElim - let vcondElim := vcondElim.map (mkAppN · (childCanonExpr.map Tree.val)) - let vcondElim := vcondElim.map (mkAppN · (newVars.map (mkFVar ·.fvarId))) - let vcondElim := vcondElim.map (mkAppN · (newConstrVars.map (mkFVar ·.fvarId))) - let vcondElim := vcondElim.map (mkAppN · monoArgs) - let vcondElim := vcondElim.map (mkAppN · (childOptimalityFiltered.map (·.val.1))) - - return Tree.node (optimality, vcondElim) childOptimality - | Tree.leaf e, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => do - trace[Meta.debug] "Optimality Leaf {e}" - return Tree.leaf (← mkAppM ``le_refl #[e], #[]) - | _, _, _, _, _, _, _ => throwError "Tree mismatch" - -/-- -/ -partial def mkNewConstrVars : - NewConstrsTree → - MetaM NewConstrVarsTree - | Tree.node newConstr childNewConstr => do - -- Recursive calls for arguments - let mut childNewConstrVar := #[] - for i in [:childNewConstr.size] do - childNewConstrVar := childNewConstrVar.push <| - ← mkNewConstrVars childNewConstr[i]! - -- Create variables - let mut newConstrVars : Array LocalDecl := #[] - for ty in newConstr do - newConstrVars := newConstrVars.push $ - LocalDecl.cdecl 0 (← mkFreshFVarId) `h ty Lean.BinderInfo.default LocalDeclKind.default - return Tree.node newConstrVars childNewConstrVar - | Tree.leaf _ => pure $ Tree.leaf () - -/-- -/ -def mkOriginalConstrVars (originalVarsDecls : Array LocalDecl) - (constr : Array (Name × Expr)) : MetaM (Array LocalDecl) := do -withExistingLocalDecls originalVarsDecls.toList do - let mut decls : Array LocalDecl := #[] - for i in [:constr.size] do - decls := decls.push $ LocalDecl.cdecl 0 (← mkFreshFVarId) (Name.mkNum constr[i]!.1 i) constr[i]!.2 Lean.BinderInfo.default LocalDeclKind.default - return decls - -/-- -/ -def mkVCondVars (originalConstrVars : Array FVarId) (vcondIdx : Tree (Array (ℕ ⊕ Expr)) Unit) : Tree (Array Expr) Unit := - vcondIdx.map (fun iores => iores.map fun iore => - match iore with - | Sum.inl i => mkFVar originalConstrVars[i]! - | Sum.inr e => e) id - -/-- -/ -def foldProdMk (exprs : Array Expr) : MetaM Expr := do - if exprs.size == 0 then - throwError "exprs cannot be empty" - let mut res := exprs[exprs.size - 1]! - for i in [1:exprs.size] do - res ← mkAppM ``Prod.mk #[exprs[exprs.size - 1 - i]!, res] - return res - -/-- -/ -def makeForwardMap (oldDomain : Expr) (xs : Array Expr) (forwardImagesNewVars : Array Expr): MetaM Expr := do +/-- Create the forward map (`φ`) from the forward images given by the processed atom tree. -/ +def makeForwardMap (oldDomain : Expr) (xs : Array Expr) (forwardImagesNewVars : Array Expr) : + MetaM Expr := do withLocalDeclD `p oldDomain fun p => do let prs := (← Meta.mkProjections oldDomain p).map (·.2.2) let forwardBody := xs ++ forwardImagesNewVars - let forwardMap ← mkLambdaFVars #[p] $ (← foldProdMk forwardBody).replaceFVars xs prs.toArray - trace[Meta.debug] "forwardMap: {forwardMap}" + let forwardMap ← mkLambdaFVars #[p] <| (← foldProdMk forwardBody).replaceFVars xs prs.toArray + trace[CvxLean.debug] "Forward map: {forwardMap}." check forwardMap - return forwardMap - -/-- -/ -def makeBackwardMap (xs : Array Expr) (mkDomainFunc : Expr → MetaM Expr) : MetaM Expr := do - let backwardBody ← foldProdMk $ xs - let backwardMap ← mkDomainFunc backwardBody - trace[Meta.debug] "backwardMap: {backwardMap}" - check backwardMap - trace[Meta.debug] "backwardMap checked" - return backwardMap + trace[CvxLean.debug] "Forward map checked." -/-- -/ -def makeObjFunForward (oldDomain : Expr) (xs : Array Expr) (originalConstrVars : Array LocalDecl) - (oldProblem : Expr) (constraints : Array Expr) (objFunEq : Expr) : MetaM Expr := do - -- ∀ {x : D}, p.constraints x → q.objFun (f x) ≤ p.objFun x - withLocalDeclD `p oldDomain fun p => do - let prs := (← Meta.mkProjections oldDomain p).map (·.2.2) - - withLocalDeclD `h (← mkAppM ``Minimization.constraints #[oldProblem, p]) fun h => do - let (_, cprs) := Meta.composeAndWithProj constraints.toList - let hProj := (cprs h).toArray - - let objFunForward ← mkAppM ``le_of_eq #[objFunEq.replaceFVars xs prs.toArray] - - let objFunForward := objFunForward.replaceFVars - ((originalConstrVars).map (mkFVar ·.fvarId)) hProj - let objFunForward ← mkLambdaFVars #[h] objFunForward - - let objFunForward := objFunForward.replaceFVars xs prs.toArray - let objFunForward ← mkLambdaFVars #[p] objFunForward - trace[Meta.debug] "objFunForward: {objFunForward}" - check objFunForward - return objFunForward - -/-- -/ -def foldAndIntro (exprs : Array Expr) : MetaM Expr := do - if exprs.size == 0 then - return mkConst ``True.intro - let mut res := exprs[exprs.size - 1]! - for i in [1:exprs.size] do - res ← mkAppM ``And.intro #[exprs[exprs.size - 1 - i]!, res] - return res + return forwardMap -/-- -/ +/-- Build the proof of `StrongEquivalence.phi_feasibility`. This uses the feasibility proofs. -/ def makeConstrForward (oldDomain : Expr) (xs : Array Expr) (originalConstrVars : Array LocalDecl) - (oldProblem : Expr) (constraints : Array Expr) (isVCond : Array Bool) (constraintsEq : Array Expr) - (feasibility : OC (Tree (Array Expr) Unit)) : MetaM Expr := do - -- ∀ {x : D}, Minimization.constraints p x → Minimization.constraints q (f x) - + (oldProblem : Expr) (constraints : Array Expr) (isVCond : Array Bool) + (constraintsEq : Array Expr) (feasibility : OC FeasibilityProofsTree) : MetaM Expr := do + -- `∀ {x : D}, p.feasible x → q.feasible (φ x)`. withLocalDeclD `p oldDomain fun p => do let prs := (← Meta.mkProjections oldDomain p).map (·.2.2) @@ -580,51 +53,67 @@ def makeConstrForward (oldDomain : Expr) (xs : Array Expr) (originalConstrVars : let mut oldConstrProofs := #[] for i in [:originalConstrVars.size] do if not isVCond[i]! then - oldConstrProofs := oldConstrProofs.push $ + oldConstrProofs := oldConstrProofs.push <| ← mkAppM ``Eq.mpr #[constraintsEq[i]!, mkFVar originalConstrVars[i]!.fvarId] -- New constraint proofs. let newConstrProofs := feasibility.fold #[] fun acc fs => fs.fold acc Array.append - let constrForwardBody ← foldAndIntro $ (oldConstrProofs ++ newConstrProofs) + let constrForwardBody ← foldAndIntro (oldConstrProofs ++ newConstrProofs) let constrForwardBody := constrForwardBody.replaceFVars ((originalConstrVars).map (mkFVar ·.fvarId)) hProj let constrForwardBody ← mkLambdaFVars #[h] constrForwardBody let constrForwardBody := constrForwardBody.replaceFVars xs prs.toArray let constrForward ← mkLambdaFVars #[p] constrForwardBody - trace[Meta.debug] "constrForward: {constrForward}" - trace[Meta.debug] "constrForwardType: {← inferType constrForward}" + trace[CvxLean.debug] "φ feasibility proof: {constrForward}." + trace[CvxLean.debug] "φ feasibility proof type: {← inferType constrForward}." check constrForward + return constrForward -/-- -/ -def makeObjFunBackward (newDomain : Expr) (canonProblem : Expr) (xs : Array Expr) (ys : Array Expr) (objFunOpt : Expr) - (canonConstrs : Array Expr) (newConstrs : Array Expr) - (newConstrVars : Array LocalDecl) : MetaM Expr := do - -- ∀ {x : E}, Minimization.constraints q x → Minimization.objFun p (g x) ≤ Minimization.objFun q x +/-- Build a proof of `StrongEquivalence.phi_optimality`. We can actually prove equality here, using +the solution-equal-atom property from the top atom that unfolds the objective function. -/ +def makeObjFunForward (oldDomain : Expr) (xs : Array Expr) (originalConstrVars : Array LocalDecl) + (oldProblem : Expr) (constraints : Array Expr) (objFunEq : Expr) : MetaM Expr := do + -- `∀ {x : D}, p.feasible x → q.objFun (φ x) ≤ p.objFun x`. + withLocalDeclD `p oldDomain fun p => do + let prs := (← Meta.mkProjections oldDomain p).map (·.2.2) - withLocalDeclD `p newDomain fun p => do - let prs := (← Meta.mkProjections newDomain p).map (·.2.2) + withLocalDeclD `h (← mkAppM ``Minimization.constraints #[oldProblem, p]) fun h => do + let (_, cprs) := Meta.composeAndWithProj constraints.toList + let hProj := (cprs h).toArray - withLocalDeclD `h (← mkAppM ``Minimization.constraints #[canonProblem, p]) fun h => do - let (_, cprs) := Meta.composeAndWithProj (canonConstrs ++ newConstrs).toList - let hProj := cprs h - let objFunBackwardBody := objFunOpt - let objFunBackwardBody := objFunBackwardBody.replaceFVars - (newConstrVars.map (mkFVar ·.fvarId)) (hProj.drop (hProj.length - newConstrVars.size)).toArray - let objFunBackwardBody := objFunBackwardBody.replaceFVars - (xs ++ ys) prs.toArray - let objFunBackward ← mkLambdaFVars #[p, h] objFunBackwardBody - trace[Meta.debug] "objFunBackward: {objFunBackward}" - check objFunBackward - return objFunBackward + let objFunForward ← mkAppM ``le_of_eq #[objFunEq.replaceFVars xs prs.toArray] + + let objFunForward := objFunForward.replaceFVars + ((originalConstrVars).map (mkFVar ·.fvarId)) hProj + let objFunForward ← mkLambdaFVars #[h] objFunForward -/-- -/ -def makeConstrBackward (vcondElimMap : Std.HashMap Nat Expr) (newDomain : Expr) (canonProblem : Expr) (xs : Array Expr) (ys : Array Expr) (constrOpt : Array Expr) - (canonConstrs : Array Expr) (newConstrs : Array Expr) (newConstrVars : Array LocalDecl) : MetaM Expr := do - -- ∀ {x : E}, Minimization.constraints q x → Minimization.constraints p (g x) + let objFunForward := objFunForward.replaceFVars xs prs.toArray + let objFunForward ← mkLambdaFVars #[p] objFunForward + trace[CvxLean.debug] "φ optimality proof: {objFunForward}." + trace[CvxLean.debug] "φ optimality proof type: {← inferType objFunForward}." + check objFunForward + + return objFunForward + +/-- Create the backward map (`ψ`), which is simply a projection. -/ +def makeBackwardMap (xs : Array Expr) (mkDomainFunc : Expr → MetaM Expr) : MetaM Expr := do + let backwardBody ← foldProdMk xs + let backwardMap ← mkDomainFunc backwardBody + trace[CvxLean.debug] "Backward map: {backwardMap}." + check backwardMap + trace[CvxLean.debug] "Backward map checked." + + return backwardMap +/-- Build a proof of `StrongEquivalence.psi_feasibility`. The proofs here come from the optimality +proofs and vcondition elimination. -/ +def makeConstrBackward (vcondElimMap : VCondElimMap) (newDomain : Expr) (canonProblem : Expr) + (xs : Array Expr) (ys : Array Expr) (constrOpt : Array Expr) (canonConstrs : Array Expr) + (newConstrs : Array Expr) (newConstrVars : Array LocalDecl) : MetaM Expr := do + -- `∀ {x : E}, q.feasible x → p.feasible (ψ x)`. withLocalDeclD `p newDomain fun p => do let prs := (← Meta.mkProjections newDomain p).map (·.2.2) @@ -635,252 +124,66 @@ def makeConstrBackward (vcondElimMap : Std.HashMap Nat Expr) (newDomain : Expr) let mut filteredCounter := 0 for i in [:constrOpt.size] do match vcondElimMap.find? i with - | some p => - constrBackwardProofs := constrBackwardProofs.push $ p - | none => - constrBackwardProofs := constrBackwardProofs.push $ mkApp constrOpt[i]! hProj.toArray[filteredCounter]! - filteredCounter := filteredCounter + 1 + | some p => + constrBackwardProofs := constrBackwardProofs.push p + | none => + constrBackwardProofs := constrBackwardProofs.push <| + mkApp constrOpt[i]! hProj.toArray[filteredCounter]! + filteredCounter := filteredCounter + 1 let constrBackwardBody ← foldAndIntro constrBackwardProofs - let constrBackwardBody := constrBackwardBody.replaceFVars - (newConstrVars.map (mkFVar ·.fvarId)) (hProj.drop (hProj.length - newConstrVars.size)).toArray + (newConstrVars.map (mkFVar ·.fvarId)) + (hProj.drop (hProj.length - newConstrVars.size)).toArray - let constrBackwardBody := constrBackwardBody.replaceFVars - (xs ++ ys) prs.toArray + let constrBackwardBody := constrBackwardBody.replaceFVars (xs ++ ys) prs.toArray let constrBackward ← mkLambdaFVars #[p, h] constrBackwardBody - trace[Meta.debug] "constrBackward: {constrBackward}" + trace[CvxLean.debug] "ψ feasibility proof: {constrBackward}." + trace[CvxLean.debug] "ψ feasibility proof type: {← inferType constrBackward}." check constrBackward - trace[Meta.debug] "constrBackward checked" + return constrBackward -/-- Extract objective and constraints in terms of the optimization variables. -/ -def mkOC (objFun : Expr) (constraints : List (Lean.Name × Expr)) (originalVarsDecls : Array LocalDecl) - : MetaM (OC Expr) := do - withExistingLocalDecls originalVarsDecls.toList do - let oc := OC.mk objFun (constraints.toArray.map (fun (c : Lean.Name × Expr) => c.2)) - return oc - -/-- Same as `mkOC` but keeping constraint tags. The objective function is tagged -with anonymous. -/ -def mkOCWithNames (objFun : Expr) (constraints : List (Lean.Name × Expr)) (originalVarsDecls : Array LocalDecl) - : MetaM (OC (Lean.Name × Expr)) := do - withExistingLocalDecls originalVarsDecls.toList do - let oc := OC.mk (`_, objFun) constraints.toArray - return oc - --- TODO: Better error message when discovering a concave atom where convex is expected, and vice versa. -/-- Construct the atom tree. -/ -def mkAtomTree (objCurv : Curvature) (originalVarsDecls extraDecls : Array LocalDecl) (oc : OC Expr) (extraVars : Array FVarId := #[]) : - MetaM ( - OC Bool × - OC (Array MessageData) × - OC (Tree GraphAtomData Expr) × - OC (Tree (Array Expr) Unit) × - OC (Tree Curvature Curvature) × - OC (Tree (Array Expr) (Array Expr))) := do -withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls extraDecls.toList do - let decls := (← getLCtx).decls.toList.filterMap fun decl? => - if let some decl := decl? then some decl.type else none - trace[Meta.debug] "decls in mkAtomTree: {decls}" - let xs := originalVarsDecls.map fun decl => mkFVar decl.fvarId - trace[Meta.debug] "mkAtomTree xs: {xs}" - -- Find atoms. - let atomsAndArgs ← OC.map2M (fun e c => findAtoms e (xs.map (·.fvarId!) ++ extraVars) c) oc - ⟨objCurv, oc.constr.map (fun _ => Curvature.ConvexSet)⟩ - let failedAtom : OC Bool := atomsAndArgs.map (·.fst) - let failedAtomMsgs : OC (Array MessageData) := atomsAndArgs.map (·.snd.fst) - if failedAtom.objFun then - throwError "Failure in objective: {failedAtomMsgs.objFun}" - - let atoms := atomsAndArgs.map (·.snd.snd.fst) - let args := atomsAndArgs.map (·.snd.snd.snd.fst) - let curvature := atomsAndArgs.map (·.snd.snd.snd.snd.fst) - let bconds := atomsAndArgs.map (·.snd.snd.snd.snd.snd) - return (failedAtom, failedAtomMsgs, atoms, args, curvature, bconds) - -/-- Identify vconditions and check that all failed constraints can be used as -vconditions. -/ -def mkVConditions (originalVarsDecls : Array LocalDecl) (oc : OC Expr) - (constraints : List (Lean.Name × Expr)) (atoms : OC (Tree GraphAtomData Expr)) (args : OC (Tree (Array Expr) Unit)) - (failedAtom : OC Bool) (failedAtomMsgs : OC (Array MessageData)) (originalConstrVars : Array LocalDecl) : - MetaM (OC (Tree (Array ℕ) Unit) × Array Bool × OC (Tree (Array Expr) Unit)) := do -withExistingLocalDecls originalVarsDecls.toList do - let vcondData ← OC.map2M (findVConditions originalConstrVars oc.constr) atoms args - - -- Extract only indicies - let vcondIdx : OC (Tree (Array ℕ) Unit) := vcondData.map - (fun t => t.map (fun iores => iores.foldl (init := #[]) (fun (acc : Array ℕ) iore => - match iore with - | Sum.inl i => acc.push i - | Sum.inr _ => acc)) - id) - - let isVCond := vcondIdx.fold (oc.constr.map (fun _ => false)) - fun acc vcondIdxTree => - vcondIdxTree.fold acc fun acc is => - is.foldl (fun acc i => acc.set! i true) acc - let vcondVars := vcondData.map $ mkVCondVars (originalConstrVars.map LocalDecl.fvarId) - for i in [:isVCond.size] do - if failedAtom.constr[i]! ∧ ¬ isVCond[i]! then - throwError "Failure in constraint {constraints.toArray[i]!.1}: {failedAtomMsgs.constr[i]!}" - return (vcondIdx, isVCond, vcondVars) - -/-- -/ -def mkSolEqAtomOC (originalVarsDecls : Array LocalDecl) - (atoms : OC (Tree GraphAtomData Expr)) - (canonWithSolution : OC CanonExprsWithSolutionTree) - (vcondVars : OC (Tree (Array Expr) Unit)) (originalConstrVars : Array LocalDecl) : MetaM (OC (Tree Expr Expr)) := -withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls originalConstrVars.toList do - let solEqAtom ← OC.map3M mkSolEqAtom atoms canonWithSolution vcondVars - trace[Meta.debug] "solEqAtom {solEqAtom}" - return solEqAtom - -/-- -/ -def mkFeasibilityOC (originalVarsDecls : Array LocalDecl) - (atoms : OC GraphAtomDataTree) - (canonWithSolution : OC CanonExprsWithSolutionTree) - (bconds : OC BCondsTree) - (vcondVars : OC VCondsTree) - (originalConstrVars : Array LocalDecl) - (solEqAtom : OC SolEqAtomProofsTree) : - MetaM (OC FeasibilityProofsTree) := -withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls originalConstrVars.toList do - let feasibility ← OC.map5M mkFeasibility atoms canonWithSolution bconds vcondVars solEqAtom - trace[Meta.debug] "feasibility {feasibility}" - return feasibility -/-- -/ -def mkCanonExprsOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) - (atoms : OC (Tree GraphAtomData Expr)) (newVars : OC (Tree (Array LocalDecl) Unit)) : MetaM (OC (Tree Expr Expr)) := -withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls newVarDecls do - let canonExprs ← OC.map2M mkCanonExprs atoms newVars - trace[Meta.debug] "canonExprs {canonExprs}" - return canonExprs - -/-- -/ -def mkNewConstrsOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) - (bconds : OC BCondsTree) - (atoms : OC (Tree GraphAtomData Expr)) (newVars : OC (Tree (Array LocalDecl) Unit)) (canonExprs : OC (Tree Expr Expr)) - : MetaM (Array Expr × OC (Tree (Array LocalDecl) Unit) × Array LocalDecl):= -withExistingLocalDecls (originalVarsDecls.toList) do - withExistingLocalDecls newVarDecls do - let newConstrs ← OC.map4M mkNewConstrs atoms bconds newVars canonExprs - trace[Meta.debug] "newConstrs {newConstrs}" - let newConstrVars ← OC.mapM mkNewConstrVars newConstrs - trace[Meta.debug] "newConstrs {newConstrs}" - let newConstrs : Array Expr := newConstrs.fold #[] fun acc cs => - cs.fold acc Array.append - let newConstrVarsArray : Array LocalDecl := newConstrVars.fold #[] fun acc cs => - cs.fold acc Array.append - return (newConstrs, newConstrVars, newConstrVarsArray) - -/-- -/ -def mkOptimalityAndVCondElimOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) - (newConstrVarsArray : Array LocalDecl) - (atoms : OC (Tree GraphAtomData Expr)) - (args : OC (Tree (Array Expr) Unit)) - (canonExprs : OC (Tree Expr Expr)) - (newVars : OC (Tree (Array LocalDecl) Unit)) - (newConstrVars : OC (Tree (Array LocalDecl) Unit)) - (curvature : OC (Tree Curvature Curvature)) - (bconds : OC (Tree (Array Expr) (Array Expr))) - (vcondIdx : OC (Tree (Array ℕ) Unit)) - : MetaM (OC (Tree Expr Expr) × Std.HashMap ℕ Expr):= - withExistingLocalDecls originalVarsDecls.toList do - withExistingLocalDecls newVarDecls do - withExistingLocalDecls newConstrVarsArray.toList do - let optimalityAndVCondElim ← OC.map7M mkOptimalityAndVCondElim atoms args canonExprs newVars newConstrVars curvature bconds - let optimality := optimalityAndVCondElim.map (fun oce => oce.map Prod.fst Prod.fst) - let vcondElim := optimalityAndVCondElim.map (fun oce => oce.map Prod.snd Prod.snd) - trace[Meta.debug] "optimality {optimality}" - trace[Meta.debug] "vcondElim {vcondElim}" - - trace[Meta.debug] "vcondIdx {vcondIdx}" - let vcondElimWithIdx ← OC.map2M (fun a b => Tree.zip a b) vcondIdx vcondElim - let vcondElimMap := vcondElimWithIdx.fold {} - fun (map : Std.HashMap Nat Expr) ci => - ci.fold map fun (map : Std.HashMap Nat Expr) ci => Id.run do - let mut res := map - for i in [:ci.1.size] do - res := res.insert ci.1[i]! ci.2[i]! - return res - return (optimality, vcondElimMap) - -/-- -/ -structure ProcessedAtomTree where - (originalVarsDecls : Array LocalDecl) - (originalConstrVars : Array LocalDecl) - (newVarDecls : List LocalDecl) - (newConstrs : Array Expr) - (newConstrVarsArray : Array LocalDecl) - (forwardImagesNewVars : Array Expr) - (constraints : List (Lean.Name × Expr)) - (isVCond : Array Bool) - (vcondElimMap : Std.HashMap ℕ Expr) - (solEqAtom : OC (Tree Expr Expr)) - (feasibility : OC (Tree (Array Expr) Unit)) - (canonExprs : OC (Tree Expr Expr)) - (optimality : OC (Tree Expr Expr)) - -instance : Inhabited ProcessedAtomTree := - ⟨⟨#[], #[], [], #[], #[], #[], [], #[], {}, default, default, default, default⟩⟩ - -/-- -/ -def mkProcessedAtomTree (objCurv : Curvature) (objFun : Expr) - (constraints : List (Lean.Name × Expr)) (originalVarsDecls : Array LocalDecl) - (extraDecls : Array LocalDecl := #[]) (extraVars : Array FVarId := #[]) : - MetaM ProcessedAtomTree := do - let oc ← mkOC objFun constraints originalVarsDecls - let (failedAtom, failedAtomMsgs, atoms, args, curvature, bconds) ← mkAtomTree objCurv originalVarsDecls (extraDecls := extraDecls) (extraVars := extraVars) oc - let originalConstrVars ← mkOriginalConstrVars originalVarsDecls constraints.toArray - let (vcondIdx, isVCond, vcondVars) ← mkVConditions originalVarsDecls oc constraints atoms args failedAtom - failedAtomMsgs originalConstrVars - let newVars ← withExistingLocalDecls originalVarsDecls.toList do - OC.map2MwithCounter mkNewVars atoms args - let newVarDecls ← mkNewVarDeclList newVars - let canonWithSolution ← withExistingLocalDecls originalVarsDecls.toList do - OC.mapM mkCanonWithSolution atoms - let forwardImagesNewVars ← withExistingLocalDecls originalVarsDecls.toList do - mkForwardImagesNewVars canonWithSolution - let solEqAtom ← mkSolEqAtomOC originalVarsDecls atoms canonWithSolution vcondVars originalConstrVars - let feasibility ← mkFeasibilityOC originalVarsDecls atoms canonWithSolution bconds vcondVars - originalConstrVars solEqAtom - let canonExprs ← mkCanonExprsOC originalVarsDecls newVarDecls atoms newVars - let (newConstrs, newConstrVars, newConstrVarsArray) - ← mkNewConstrsOC originalVarsDecls newVarDecls bconds atoms newVars canonExprs - let (optimality, vcondElimMap) ← mkOptimalityAndVCondElimOC originalVarsDecls newVarDecls - newConstrVarsArray atoms args canonExprs newVars newConstrVars curvature bconds vcondIdx - - return ProcessedAtomTree.mk - (originalVarsDecls := originalVarsDecls) - (originalConstrVars := originalConstrVars) - (newVarDecls := newVarDecls) - (newConstrs := newConstrs) - (newConstrVarsArray := newConstrVarsArray) - (forwardImagesNewVars := forwardImagesNewVars) - (constraints := constraints) - (isVCond := isVCond) - (vcondElimMap := vcondElimMap) - (solEqAtom := solEqAtom) - (feasibility := feasibility) - (canonExprs := canonExprs) - (optimality := optimality) +/-- Build a proof of `StrongEquivalence.psi_optimality`. The proofs here come from optimality proofs +on the objective function component. -/ +def makeObjFunBackward (newDomain : Expr) (canonProblem : Expr) (xs : Array Expr) (ys : Array Expr) + (objFunOpt : Expr) (canonConstrs : Array Expr) (newConstrs : Array Expr) + (newConstrVars : Array LocalDecl) : MetaM Expr := do + -- `∀ {x : E}, q.feasible x → p.objFun (ψ x) ≤ q.objFun x`. + withLocalDeclD `p newDomain fun p => do + let prs := (← Meta.mkProjections newDomain p).map (·.2.2) + + withLocalDeclD `h (← mkAppM ``Minimization.constraints #[canonProblem, p]) fun h => do + let (_, cprs) := Meta.composeAndWithProj (canonConstrs ++ newConstrs).toList + let hProj := cprs h + let objFunBackwardBody := objFunOpt + + let objFunBackwardBody := objFunBackwardBody.replaceFVars + (newConstrVars.map (mkFVar ·.fvarId)) + (hProj.drop (hProj.length - newConstrVars.size)).toArray + + let objFunBackwardBody := objFunBackwardBody.replaceFVars + (xs ++ ys) prs.toArray + + let objFunBackward ← mkLambdaFVars #[p, h] objFunBackwardBody + trace[CvxLean.debug] "ψ optimality proof: {objFunBackward}." + trace[CvxLean.debug] "ψ optimality proof type: {← inferType objFunBackward}." + check objFunBackward + + return objFunBackward open Meta Elab Tactic -/-- TODO -/ +/-- Main canonization procedure. Given a minimization expression, return the canonized problem, in +conic form, and a proof of equivalence. -/ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) := do let ogProblemExpr := ogProblem.toExpr let D := ogProblem.domain let R := ogProblem.codomain -- Get `objFun` and `constraints` without projections (e.g., `p.1 + p.2`) but rather with - -- declared variables in `originalVarsDecls` (e.g., `x + y`). + -- declared variables in `originalVarsDecls` (e.g., `x + y`). let (objFun, constraints, originalVarsDecls) ← withLambdaBody ogProblem.constraints fun p constraints => do let pr := (← Meta.mkProjections D p).toArray @@ -897,15 +200,15 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : let objFun ← Meta.replaceProjections objFunP p.fvarId! xs return (objFun, constraints, originalVarsDecls) - -- Process the atom tree. + -- Make processed atom tree. let pat ← mkProcessedAtomTree Curvature.Convex objFun constraints originalVarsDecls - -- Create canon problem and equivalence proof. + -- Create canonized problem and equivalence proof. withExistingLocalDecls pat.originalVarsDecls.toList do -- Original problem variables. let originalVars := pat.originalVarsDecls.map fun decl => mkFVar decl.fvarId - -- Forward map: φ. + -- Make forward map (`φ`), with properties. let forwardMap ← makeForwardMap D originalVars pat.forwardImagesNewVars let (objFunForward, constrForward) ← @@ -913,20 +216,20 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : let objFunForward ← makeObjFunForward D originalVars pat.originalConstrVars ogProblemExpr (pat.constraints.toArray.map Prod.snd) pat.solEqAtom.objFun.val let constrForward ← makeConstrForward D originalVars pat.originalConstrVars ogProblemExpr - (pat.constraints.toArray.map Prod.snd) pat.isVCond (pat.solEqAtom.constr.map Tree.val) + (pat.constraints.toArray.map Prod.snd) pat.isVCond (pat.solEqAtom.constrs.map Tree.val) pat.feasibility return (objFunForward, constrForward) withExistingLocalDecls pat.newVarDecls do - -- New variables added by the canonization. + -- New variables introduced by the canonization. let newVars := (pat.newVarDecls.map (mkFVar ·.fvarId)).toArray - -- Canon variables: originalVars ⊎ newVars. + -- Canonized variables: originalVars ⊎ newVars. let canonVars ← (originalVars ++ newVars).mapM fun x => do let decl ← x.fvarId!.getDecl return (decl.userName, decl.type) - -- New domain: D × T where T is the domain of the new variables. + -- New domain: `D × T` where `T` is the domain of the new variables. let E := Meta.composeDomain canonVars.toList -- Function to replace variables by projections in the new domain. @@ -936,8 +239,8 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : let e := Expr.replaceFVars e (originalVars ++ newVars) prs.toArray mkLambdaFVars #[p] e - -- Canon problem. - let canonConstrs := pat.canonExprs.constr.map Tree.val + -- Canonize problem. + let canonConstrs := pat.canonExprs.constrs.map Tree.val let canonConstrs := canonConstrs.filterIdx (fun i => ¬ pat.isVCond[i]!) let canonProblem : MinimizationExpr := { domain := E @@ -946,7 +249,7 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : constraints := ← mkDomain <| Meta.composeAnd (canonConstrs ++ pat.newConstrs).toList } let canonProblemExpr := canonProblem.toExpr - -- Backward map: ψ. + -- Make backward map (`ψ`), with properties. let backwardMap ← makeBackwardMap originalVars mkDomain let (objFunBackward, constrBackward) ← @@ -955,25 +258,25 @@ def canonize (ogProblem : MinimizationExpr) : MetaM (MinimizationExpr × Expr) : pat.optimality.objFun.val canonConstrs pat.newConstrs pat.newConstrVarsArray let constrBackward ← makeConstrBackward pat.vcondElimMap E canonProblemExpr originalVars - newVars (pat.optimality.constr.map (·.val)) canonConstrs pat.newConstrs + newVars (pat.optimality.constrs.map (·.val)) canonConstrs pat.newConstrs pat.newConstrVarsArray return (objFunBackward, constrBackward) - -- Combine forward and backward maps into equivalence witness. + -- Combine forward and backward maps and their properties into an equivalence witness. let strongEqvProof ← mkAppOptM ``Minimization.StrongEquivalence.mk #[D, E, R, none, ogProblemExpr, canonProblemExpr, - -- phi + -- `phi`. forwardMap, - -- psi + -- `psi`. backwardMap, - -- phi_feasibility + -- `phi_feasibility` constrForward, - -- psi_feasibility + -- `psi_feasibility` constrBackward, - -- phi_optimality + -- `phi_optimality` objFunForward, - -- psi_optimality + -- `psi_optimality` objFunBackward] let eqvProof ← mkAppM ``Minimization.Equivalence.ofStrongEquivalence #[strongEqvProof] @@ -983,7 +286,7 @@ def dcpBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let ogProblem ← eqvExpr.toMinimizationExprLHS let (_, eqvProof) ← canonize ogProblem if ! (← isDefEq (mkMVar g) eqvProof) then - throwError "DCP failed to prove equivalence" + throwDCPError "failed to prove equivalence." end DCP @@ -991,6 +294,9 @@ namespace Tactic open Lean.Elab Lean.Elab.Tactic +/-- Tactic to canonize a problem into DCP form. It can be used directly in an `equivalence` or +`reduction` environment. However, note, that the main use of this transformation is to solve +problems using the `solve` command. -/ syntax (name := dcp) "dcp" : tactic @[tactic dcp] diff --git a/CvxLean/Tactic/DCP/DCPFindAtoms.lean b/CvxLean/Tactic/DCP/DCPFindAtoms.lean new file mode 100644 index 00000000..d2a7b619 --- /dev/null +++ b/CvxLean/Tactic/DCP/DCPFindAtoms.lean @@ -0,0 +1,159 @@ +import CvxLean.Meta.Util.Expr +import CvxLean.Tactic.Arith.Arith +import CvxLean.Tactic.DCP.AtomExt +import CvxLean.Tactic.DCP.DCPTypes + +/-! +# Finding DCP atom trees + +This file defines the procedure to turn an expression into a tree of atoms. This is the first step +of the canonization algorithm. At this point, only background conditions and curvatures are checked. + +This is then applied to every component in `mkAtomTree` in `CvxLean/Tactic/DCP/DCPMakers.lean` to +make a full atom tree for the problem. +-/ + +namespace CvxLean + +open Lean Meta Expr + +namespace DCP + +/-- Return list of all registered atoms that match with a given expression. For every registered +atom, it returns: +* The list of arguments as Lean expressions for further matching. +* The atom entry, of type `GraphAtomData`. + +Recall that the atom library is stored using a discrimination tree. + +There might be multiple atoms that match an expression. This function returns all matches. -/ +def findRegisteredAtoms (e : Expr) : MetaM (Array (Arguments × GraphAtomData)) := do + let discrTree ← getAtomDiscrTree + let atoms ← discrTree.getMatch (← zetaReduce (← e.removeMData)) + trace[Meta.debug] "discrTree {atoms.size} {e}" + let mut goodAtoms := #[] + for atom in atoms do + let (mvars, _, pattern) ← lambdaMetaTelescope atom.expr + if ← isDefEq pattern e then + let args ← mvars.mapM instantiateMVars + goodAtoms := goodAtoms.push (args, atom.graph!) + else + trace[CvxLean.debug] + "Pattern did not match. (Pattern {toString pattern}; Expression {toString e})" + -- Heuristic sorting of potential atoms to use: larger expressions have priority. + goodAtoms := goodAtoms.insertionSort (fun a b => (a.2.expr.size - b.2.expr.size != 0)) + return goodAtoms + +/-- Data type used to indicate whether an atom could be successfully processed. It is used by +`processAtom`. A successful match includes 4 trees: +* A tree of atom data (see `GraphAtomData`). +* A tree of arguments for every node. +* A tree of curvatures. +* A tree of background conditions. -/ +inductive FindAtomResult + | Success (res : AtomDataTrees) + | Error (msgs : Array MessageData) + +/-- Given an expression `e`, optimization variables `vars` and the expected curvature `curvature`, +this function attempts to recursively match `e` with atoms for the library. + +It outputs all the necessary information as explained in `FindAtomResult`. The boolean in the output +indicates whether a match from the library was used. -/ +partial def findAtoms (e : Expr) (vars : Array FVarId) (curvature : Curvature) : + MetaM (Bool × Array MessageData × AtomDataTrees) := do + let e := eta e + + -- Variable case. + if e.isFVar ∧ vars.contains e.fvarId! then + return (false, #[], Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) + + -- Constant case. + let isConstantExpr := e.isRelativelyConstant vars || curvature == Curvature.Constant + let isPropExpr := (← inferType e).isProp + if isConstantExpr && !isPropExpr then + return (false, #[], Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) + + -- Operator case. + let potentialAtoms ← findRegisteredAtoms e + let mut failedAtoms : Array MessageData := #[] + if potentialAtoms.size == 0 then + failedAtoms := failedAtoms.push m!"No atom found for {e}" + -- Go through all the matches, and stop on the first successful one. + for (args, atom) in potentialAtoms do + match ← processAtom e vars curvature atom args with + | FindAtomResult.Success (atoms, args, curvatures, bconds) => + trace[Meta.debug] "Success {e} : {atom}" + return (false, failedAtoms, atoms, args, curvatures, bconds) + | FindAtomResult.Error msg => + failedAtoms := failedAtoms ++ msg + return (true, failedAtoms, Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf #[]) +where + /-- This functon takes data from the atom library and succeeds if it can use it to fu-/ + processAtom (e : Expr) (vars : Array FVarId) (curvature : Curvature) (atom : GraphAtomData) + (args : Array Expr) : MetaM FindAtomResult := do + trace[CvxLean.debug] "Processing atom {atom.expr} for expression {e}, curvature {curvature}" + if atom.curvature != Curvature.Affine ∧ curvature != atom.curvature then + return FindAtomResult.Error + #[m!"Trying atom {atom.expr} for expression {e}: " ++ + m!"Expected {curvature}, but atom is {atom.curvature}"] + let mut bconds := #[] + for (_, bcondType) in atom.bconds do + let bcondType := mkAppNBeta bcondType args + let fvarId? ← (← getLCtx).decls.findSomeM? fun decl => match decl with + | none => pure none + | some decl => do + if ← isDefEq decl.type bcondType then return some decl.fvarId else return none + match fvarId? with + | some fvarId => + bconds := bconds.push (mkFVar fvarId) + | none => + -- Try to prove simple bconditions by `arith`. + let (e, _) ← Lean.Elab.Term.TermElabM.run <| Lean.Elab.Term.commitIfNoErrors? <| do + let v ← Lean.Elab.Term.elabTerm (← `(by arith)).raw (some bcondType) + Lean.Elab.Term.synthesizeSyntheticMVarsNoPostponing + let v ← instantiateMVars v + return v + if let some e' := e then + bconds := bconds.push e' + else + let declsTypes := (← getLCtx).decls.toList.filterMap (Option.map LocalDecl.type) + return FindAtomResult.Error + #[m!"Trying atom {atom.expr} for expression {e}: " ++ + m!"Background Condition {bcondType} not found. (Local context: {declsTypes})"] + + let mut childTrees := #[] + let mut childArgsTrees := #[] + let mut childCurvatures := #[] + let mut childBConds := #[] + for i in [:args.size] do + let arg := args[i]! + if atom.argKinds[i]! == ArgKind.Constant ∧ not (arg.isRelativelyConstant vars) then + return FindAtomResult.Error + #[m!"Trying atom {atom.expr} for expression {e}: " ++ + m!"Expected constant argument, but found: {arg}"] + -- Apply curvature rule to find expected curvature of the argument. + let c := curvatureInArg curvature atom.argKinds[i]! + trace[CvxLean.debug] "Trying to find atoms for {arg} with curvature {c}." + -- Recursion happends here, where we try to find atoms + let (childFailed, childFailedAtoms, childTree, childArgsTree, childCurvature, childBCond) ← + findAtoms arg vars c + if childFailed then + return FindAtomResult.Error childFailedAtoms + childTrees := childTrees.push childTree + childArgsTrees := childArgsTrees.push childArgsTree + childCurvatures := childCurvatures.push childCurvature + childBConds := childBConds.push childBCond + if curvature == Curvature.Affine then + -- For affine branches, we simply return the expression. + return FindAtomResult.Success + (Tree.leaf e, Tree.leaf (), Tree.leaf curvature, Tree.leaf bconds) + else + return FindAtomResult.Success ( + Tree.node atom childTrees, + Tree.node args childArgsTrees, + Tree.node curvature childCurvatures, + Tree.node bconds childBConds) + +end DCP + +end CvxLean diff --git a/CvxLean/Tactic/DCP/DCPMakers.lean b/CvxLean/Tactic/DCP/DCPMakers.lean new file mode 100644 index 00000000..2587b745 --- /dev/null +++ b/CvxLean/Tactic/DCP/DCPMakers.lean @@ -0,0 +1,517 @@ +import CvxLean.Meta.Util.Expr +import CvxLean.Meta.Util.Error +import CvxLean.Tactic.Arith.Arith +import CvxLean.Tactic.DCP.DCPTypes +import CvxLean.Tactic.DCP.DCPFindAtoms + +/-! +# DCP makers + +This is the "processing" step of the DCP procedure. Given a starting problem, we make an atom tree, +and package all the necessary information that is needed to build the equivalence into a +`ProcessedAtomTree` (see `mkProcessedAtomTree`). + +See `CvxLean/Tactic/DCP/DCP.lean` for the canonization algorithm that uses the processed atom tree +to build the equivalence. +-/ + +namespace CvxLean + +open Lean Meta Expr + +namespace DCP + +section MainMakers + +/-- Create local declarations for each of the constraints in the original problem. This function is +the only one in this section that does not operate on the canonized problem. -/ +def mkOriginalConstrVars (originalVarsDecls : Array LocalDecl) (constr : Array (Name × Expr)) : + MetaM (Array LocalDecl) := do + withExistingLocalDecls originalVarsDecls.toList do + let mut decls : Array LocalDecl := #[] + for i in [:constr.size] do + decls := decls.push <| LocalDecl.cdecl 0 (← mkFreshFVarId) (Name.mkNum constr[i]!.1 i) + constr[i]!.2 Lean.BinderInfo.default LocalDeclKind.default + return decls + +/-- Build the atom tree for the whole problem. Essentially, use `findAtoms` on every component, +setting the expected curvature of the objective function to "convex" and the expected curvature for +each of the constraints to "convex set". -/ +def mkAtomTree (objCurv : Curvature) (originalVarsDecls extraDecls : Array LocalDecl) (oc : OC Expr) + (extraVars : Array FVarId := #[]) : + MetaM ( + OC Bool × OC (Array MessageData) × + OC GraphAtomDataTree × OC ArgumentsTree × OC CurvatureTree × OC BCondsTree) := do + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls extraDecls.toList do + let xs := originalVarsDecls.map fun decl => mkFVar decl.fvarId + -- Find atoms. + let atomsAndArgs ← OC.map2M (fun e c => findAtoms e (xs.map (·.fvarId!) ++ extraVars) c) oc + ⟨objCurv, oc.constrs.map (fun _ => Curvature.ConvexSet)⟩ + let failedAtom : OC Bool := atomsAndArgs.map (·.fst) + let failedAtomMsgs : OC (Array MessageData) := atomsAndArgs.map (·.snd.fst) + if failedAtom.objFun then + throwDCPError + "failed to make atom tree for the objective function ({failedAtomMsgs.objFun})." + + let atoms := atomsAndArgs.map (·.snd.snd.fst) + let args := atomsAndArgs.map (·.snd.snd.snd.fst) + let curvature := atomsAndArgs.map (·.snd.snd.snd.snd.fst) + let bconds := atomsAndArgs.map (·.snd.snd.snd.snd.snd) + + return (failedAtom, failedAtomMsgs, atoms, args, curvature, bconds) + +/-- Returns a tree of the new variables added at every node in the atom tree and the total number of +new variables. Freshness is guaranteed by giving each variable a unique name here, using the +counter. -/ +partial def mkNewVars (i : Nat) : GraphAtomDataTree → ArgumentsTree → MetaM (NewVarsTree × Nat) + | Tree.node atom childAtoms, Tree.node args childArgs => do + let mut childNewVars := #[] + let mut i := i + for k in [:childAtoms.size] do + let (childNewVar, i') ← mkNewVars i childAtoms[k]! childArgs[k]! + i := i' + childNewVars := childNewVars.push $ childNewVar + let mut newVars : Array LocalDecl := #[] + for (n, ty) in atom.impVars do + newVars := newVars.push <| + LocalDecl.cdecl 0 (← mkFreshFVarId) (Name.mkNum n i) (mkAppNBeta ty args) + Lean.BinderInfo.default LocalDeclKind.default + i := i + 1 + return (Tree.node newVars childNewVars, i) + | Tree.leaf _, Tree.leaf _ => pure (Tree.leaf (), i) + | _, _ => throwDCPError "new variables tree mismatch." + +/-- Turn `NewVarTree`s (which hold local declarations at the nodes) into one list of local +declarations. -/ +def mkNewVarDeclList (newVars : OC NewVarsTree) : MetaM (List LocalDecl) := do + let newVarTrees := #[newVars.objFun] ++ newVars.constrs + let mut newVarDecls := #[] + for t in newVarTrees do + newVarDecls := t.fold newVarDecls Array.append + return newVarDecls.toList + +/-- Extract the atoms' implementation objectives from the atom tree so that the resulting tree +represents the canonized expressions. -/ +partial def mkCanonExprs : GraphAtomDataTree → NewVarsTree → MetaM CanonExprsTree + | Tree.node atom childAtoms, Tree.node newVars childNewVars => do + let mut childCanonExprs := #[] + for i in [:childAtoms.size] do + childCanonExprs := childCanonExprs.push <| ← mkCanonExprs childAtoms[i]! childNewVars[i]! + let canonExpr := atom.impObjFun + let canonExpr := mkAppNBeta canonExpr (childCanonExprs.map (·.val)) + let canonExpr := mkAppNBeta canonExpr (newVars.map (mkFVar ·.fvarId)) + return Tree.node canonExpr childCanonExprs + | Tree.leaf e, Tree.leaf () => pure <| Tree.leaf e + | _, _ => throwDCPError "canonized expressions tree mismatch." + +/-- Put together all the new constraints added by each of the atoms in the atom tree. -/ +partial def mkNewConstrs : GraphAtomDataTree → BCondsTree → NewVarsTree → CanonExprsTree → + MetaM NewConstrsTree + | Tree.node atom childAtoms, Tree.node bconds childBConds, Tree.node newVars childNewVars, + Tree.node _canonExprs childCanonExprs => do + let mut childNewConstrs := #[] + for i in [:childAtoms.size] do + childNewConstrs := childNewConstrs.push <| + ← mkNewConstrs childAtoms[i]! childBConds[i]! childNewVars[i]! childCanonExprs[i]! + let newConstrs := atom.impConstrs + let newConstrs := newConstrs.map (mkAppNBeta · (childCanonExprs.map Tree.val)) + let newConstrs := newConstrs.map (mkAppNBeta · bconds) + let newConstrs := newConstrs.map (mkAppNBeta · (newVars.map (mkFVar ·.fvarId))) + return Tree.node newConstrs childNewConstrs + | Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => pure <| Tree.leaf () + | _, _, _, _ => throwDCPError "new constraints tree mismatch." + +/-- Turn a tree of new constraints into a tree of new constraint *variables*, creating a local +declaration for each new constraint. -/ +partial def mkNewConstrVars : NewConstrsTree → MetaM NewConstrVarsTree + | Tree.node newConstr childNewConstr => do + -- Recursive calls for arguments. + let mut childNewConstrVar := #[] + for i in [:childNewConstr.size] do + childNewConstrVar := childNewConstrVar.push <| ← mkNewConstrVars childNewConstr[i]! + -- Create variables for the new constraints. + let mut newConstrVars : Array LocalDecl := #[] + for ty in newConstr do + newConstrVars := newConstrVars.push <| + LocalDecl.cdecl 0 (← mkFreshFVarId) `h ty Lean.BinderInfo.default LocalDeclKind.default + return Tree.node newConstrVars childNewConstrVar + | Tree.leaf _ => pure <| Tree.leaf () + +/-- Same as `mkCanonExprs` but also attach the solutions for each of the atoms unfolded. -/ +partial def mkCanonWithSolution : GraphAtomDataTree → MetaM CanonExprsWithSolutionTree + | Tree.node atom childAtoms => do + let mut childCanonExprs := #[] + for i in [:childAtoms.size] do + childCanonExprs := childCanonExprs.push <| ← mkCanonWithSolution childAtoms[i]! + let sols := (atom.solution.map (mkAppNBeta · (childCanonExprs.1.map (·.val.1)).toArray)) + let canonExpr := atom.impObjFun + let canonExpr := mkAppNBeta canonExpr (childCanonExprs.1.map (·.val.1)).toArray + let canonExpr := mkAppNBeta canonExpr sols + return Tree.node ⟨canonExpr, sols⟩ childCanonExprs + | Tree.leaf e => + return Tree.leaf ⟨e, #[]⟩ + +/-- Combine all the solutions introduced by every atom expansion. A solution tells us how to assign +new variables from old ones. Here, we collect all such assignments, which will be used later on to +build the forward map between the original and canonized problems. -/ +def mkForwardImagesNewVars (canonWithSolution : OC CanonExprsWithSolutionTree) : + MetaM Solution := do + let canonWithSolutionTrees := #[canonWithSolution.objFun] ++ canonWithSolution.constrs + let mut fm := #[] + for t in canonWithSolutionTrees do + fm := t.fold fm (fun acc a => Array.append acc a.solution) + return fm + +/-- Combine all proofs of "solution-equals-atom" into one tree. A little work is needed here so that +the solutions are in term of the optimization variables and not other "new" variables, so we need to +use the solution-equal-atom proofs from the children to adjust the proofs as required. -/ +partial def mkSolEqAtom : GraphAtomDataTree → CanonExprsWithSolutionTree → VCondsTree → + MetaM SolEqAtomProofsTree + | Tree.node atom childAtoms, + Tree.node _canonWithSolution childCanonWithSolution, + Tree.node vcondVars childVCondVars => do + -- Recursive calls for arguments. + let mut childSolEqAtom := #[] + for i in [:childAtoms.size] do + childSolEqAtom := childSolEqAtom.push <| + ← mkSolEqAtom childAtoms[i]! childCanonWithSolution[i]! childVCondVars[i]! + -- Rewrite arguments in atom expr. + let mut solEqAtomR ← mkEqRefl atom.expr + for c in childSolEqAtom do + solEqAtomR ← mkCongr solEqAtomR c.val + -- Use `solEqAtom` of children to rewrite the arguments in the `vconditions`. + let mut vconds := #[] + for i in [:vcondVars.size] do + let mut vcondEqCanonVCond ← mkEqRefl atom.vconds[i]!.2 + for c in childSolEqAtom do + vcondEqCanonVCond ← mkCongr vcondEqCanonVCond c.val + vconds := vconds.push <| ← mkEqMPR vcondEqCanonVCond vcondVars[i]! + -- Apply solEqAtom property of the atom. + let solEqAtomL := atom.solEqAtom + let solEqAtomL := mkAppN solEqAtomL (childCanonWithSolution.map (·.val.1)) + let solEqAtomL := mkAppN solEqAtomL vconds + let solEqAtom ← mkEqTrans solEqAtomL solEqAtomR + return Tree.node solEqAtom childSolEqAtom + | Tree.leaf e, Tree.leaf _, Tree.leaf _ => do return Tree.leaf (← mkEqRefl e) + | _, _, _ => throwDCPError "solution-equals-atom tree mismatch." + +/-- Combine all feasibility proofs. Again, we need to replace the solutions using solution-equals- +atom so that all the feasibility proofs are with respect to the optimization variables. -/ +partial def mkFeasibility : GraphAtomDataTree → CanonExprsWithSolutionTree → BCondsTree → + VCondsTree → SolEqAtomProofsTree → MetaM FeasibilityProofsTree + | Tree.node atom childAtoms, + Tree.node _canonWithSolution childCanonWithSolution, + Tree.node bconds childBConds, + Tree.node vcondVars childVCondVars, + Tree.node _solEqAtom childSolEqAtom => do + -- Recursive calls for arguments. + let mut childFeasibility := #[] + for i in [:childAtoms.size] do + let c ← mkFeasibility childAtoms[i]! childCanonWithSolution[i]! childBConds[i]! + childVCondVars[i]! childSolEqAtom[i]! + childFeasibility := childFeasibility.push c + -- Use solEqAtom of children to rewrite the arguments in the vconditions. + let mut vconds := #[] + for i in [:vcondVars.size] do + let mut vcondEqCanonVCond ← mkEqRefl atom.vconds[i]!.2 + for c in childSolEqAtom do + vcondEqCanonVCond ← mkCongr vcondEqCanonVCond c.val + vconds := vconds.push <| ← mkEqMPR vcondEqCanonVCond vcondVars[i]! + -- Apply feasibility property of the atom. + let feasibility := atom.feasibility + let feasibility := feasibility.map (mkAppN · (childCanonWithSolution.map (·.val.1))) + let feasibility := feasibility.map (mkAppN · bconds) + let feasibility := feasibility.map (mkAppN · vconds) + let _ ← feasibility.mapM check + return Tree.node feasibility childFeasibility + | Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => return Tree.leaf () + | _, _, _, _, _ => throwError "feasibility proofs tree mismatch." + +/-- Gather all the optimality and vcondition elimination proofs. Affine atoms in the role of convex +or concave need to be handled separately and their proof of optimality needs to be adjusted. -/ +partial def mkOptimalityAndVCondElim : GraphAtomDataTree → ArgumentsTree → CanonExprsTree → + NewVarsTree → NewConstrVarsTree → CurvatureTree → BCondsTree → + MetaM OptimalityAndVCondElimProofsTree + | Tree.node atom childAtoms, + Tree.node args childArgs, + Tree.node _canonExpr childCanonExpr, + Tree.node newVars childNewVars, + Tree.node newConstrVars childNewConstrVars, + Tree.node curvature childCurvature, + Tree.node bconds childBConds => do + -- Recursive calls for arguments. + let mut childOptimality := #[] + let mut childOptimalityFiltered := #[] + for i in [:childAtoms.size] do + if atom.argKinds[i]! != ArgKind.Constant && atom.argKinds[i]! != ArgKind.Neither then + let opt ← mkOptimalityAndVCondElim childAtoms[i]! childArgs[i]! childCanonExpr[i]! + childNewVars[i]! childNewConstrVars[i]! childCurvature[i]! childBConds[i]! + childOptimality := childOptimality.push opt + childOptimalityFiltered := childOptimalityFiltered.push opt + else + childOptimality := childOptimality.push <| Tree.leaf (mkStrLit s!"dummy {args[i]!}", #[]) + let mut monoArgs := #[] + for i in [:args.size] do + if atom.argKinds[i]! != ArgKind.Constant && atom.argKinds[i]! != ArgKind.Neither then + monoArgs := monoArgs.push args[i]! + + -- Apply optimality property of atom. Recall that the optimality property of affine atoms is + -- both the property of concave and convex atoms. That is so that here we can use either as + -- needed. + let optimality ← + match curvature, atom.curvature with + | Curvature.Concave, Curvature.Affine => mkAppM ``And.left #[atom.optimality] + | Curvature.Convex, Curvature.Affine => mkAppM ``And.right #[atom.optimality] + | _, _ => pure atom.optimality + let optimality := mkAppN optimality (childCanonExpr.map Tree.val) + let optimality := mkAppN optimality bconds + check optimality + let optimality := mkAppN optimality (newVars.map (mkFVar ·.fvarId)) + let optimality := mkAppN optimality (newConstrVars.map (mkFVar ·.fvarId)) + let optimality := mkAppN optimality monoArgs + let optimality := mkAppN optimality (childOptimalityFiltered.map (·.val.1)) + check optimality + + -- Apply vcond elim property of atom. + let vcondElim := atom.vcondElim + let vcondElim := vcondElim.map (mkAppN · (childCanonExpr.map Tree.val)) + let vcondElim := vcondElim.map (mkAppN · (newVars.map (mkFVar ·.fvarId))) + let vcondElim := vcondElim.map (mkAppN · (newConstrVars.map (mkFVar ·.fvarId))) + let vcondElim := vcondElim.map (mkAppN · monoArgs) + let vcondElim := vcondElim.map (mkAppN · (childOptimalityFiltered.map (·.val.1))) + + return Tree.node (optimality, vcondElim) childOptimality + | Tree.leaf e, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _, Tree.leaf _ => do + trace[CvxLean.debug] "Optimality Leaf {e}." + return Tree.leaf (← mkAppM ``le_refl #[e], #[]) + | _, _, _, _, _, _, _ => + throwDCPError "optimality and vcondition elimination proof tree mismatch." + +end MainMakers + +section VariableConditions + +/- These are the functions that handle variable conditions, including inference. -/ + +/-- Proofs of variable conditions may come from a constraint exactly (in which case we return the +constraint's identifier), or they might be inferred, in which case we return the proof. -/ +def mkVCondVars (originalConstrFVars : Array FVarId) (vcondIdx : PreVCondsTree) : + Tree (Array Expr) Unit := + vcondIdx.map (fun iores => iores.map fun iore => + match iore with + | Sum.inl i => mkFVar originalConstrFVars[i]! + | Sum.inr e => e) id + +/-- Returns index of constraint if vcondition corresponds exactly to a constraint (this is needed +for condition elimination). If not, it tries to deduce the condition from other constraints and +returns an expression. -/ +partial def findVConditions (originalConstrVars : Array LocalDecl) (constraints : Array Expr) : + GraphAtomDataTree → ArgumentsTree → MetaM PreVCondsTree + | Tree.node atom childAtoms, Tree.node args childArgs => do + let mut childrenVCondData := #[] + for i in [:childAtoms.size] do + let childVCondData ← + findVConditions originalConstrVars constraints childAtoms[i]! childArgs[i]! + childrenVCondData := childrenVCondData.push childVCondData + let mut vcondData := #[] + for (_, vcond) in atom.vconds do + let vcond := mkAppNBeta vcond args + + -- First, try to see if it matches exactly with any of the constraints. + match ← constraints.findIdxM? (isDefEq vcond) with + | some i => do vcondData := vcondData.push (Sum.inl i) + | none => + -- If not, try to infer vconditions from constraints. + vcondData := ← + withExistingLocalDecls originalConstrVars.toList do + let (e, _) ← Lean.Elab.Term.TermElabM.run <| Lean.Elab.Term.commitIfNoErrors? <| do + let tac ← `(by arith) + let v ← Lean.Elab.Term.elabTerm tac.raw (some vcond) + Lean.Elab.Term.synthesizeSyntheticMVarsNoPostponing + instantiateMVars v + + if let some e' := e then + return vcondData.push (Sum.inr e') + else + throwDCPError ("variable condition for {atom.id} not found or inferred." ++ + "\n{vcond}\n{constraints}.") + + return (Tree.node vcondData childrenVCondData) + | Tree.leaf _, Tree.leaf _ => pure (Tree.leaf ()) + | _, _ => throwDCPError "variable conditions tree mismatch." + +/-- Identify vconditions and check that all failed constraints can be used as vconditions. This also +separates the vconditions that need to be eliminated by putting all their indices together. -/ +def mkVConditions (originalVarsDecls : Array LocalDecl) (oc : OC Expr) + (constraints : List (Name × Expr)) (atoms : OC GraphAtomDataTree) (args : OC ArgumentsTree) + (failedAtom : OC Bool) (failedAtomMsgs : OC (Array MessageData)) + (originalConstrVars : Array LocalDecl) : + MetaM (OC VCondsIdxsTree × Array Bool × OC VCondsTree) := do + withExistingLocalDecls originalVarsDecls.toList do + let vcondData ← OC.map2M (findVConditions originalConstrVars oc.constrs) atoms args + + -- Extract only indicies + let vcondIdx : OC VCondsIdxsTree := vcondData.map + (fun t => t.map (fun iores => iores.foldl (init := #[]) (fun (acc : Array ℕ) iore => + match iore with + | Sum.inl i => acc.push i + | Sum.inr _ => acc)) + id) + + let isVCond := vcondIdx.fold (oc.constrs.map (fun _ => false)) + fun acc vcondIdxTree => + vcondIdxTree.fold acc fun acc is => + is.foldl (fun acc i => acc.set! i true) acc + let vcondVars := vcondData.map <| mkVCondVars (originalConstrVars.map LocalDecl.fvarId) + for i in [:isVCond.size] do + if failedAtom.constrs[i]! && !isVCond[i]! then + throwDCPError + "failure in constraint {constraints.toArray[i]!.1} ({failedAtomMsgs.constrs[i]!})." + return (vcondIdx, isVCond, vcondVars) + +end VariableConditions + +section OC + +/- Lift all the basic makers to operate on `OC`s. -/ + +/-- Extract objective and constraints in terms of the optimization variables. -/ +def mkOC (objFun : Expr) (constraints : List (Name × Expr)) (originalVarsDecls : Array LocalDecl) : + MetaM (OC Expr) := do + withExistingLocalDecls originalVarsDecls.toList do + return OC.mk objFun (constraints.toArray.map (fun (c : Name × Expr) => c.2)) + +/-- Same as `mkOC` but keeping constraint tags. The objective function is tagged with anonymous. -/ +def mkOCWithNames (objFun : Expr) (constraints : List (Name × Expr)) + (originalVarsDecls : Array LocalDecl) : MetaM (OC (Name × Expr)) := do + withExistingLocalDecls originalVarsDecls.toList do + let oc := OC.mk (`_, objFun) constraints.toArray + return oc + +/-- Lioft `mkSolEqAtom` to `OC`. -/ +def mkSolEqAtomOC (originalVarsDecls : Array LocalDecl) (atoms : OC GraphAtomDataTree) + (canonWithSolution : OC CanonExprsWithSolutionTree) (vcondVars : OC VCondsTree) + (originalConstrVars : Array LocalDecl) : MetaM (OC SolEqAtomProofsTree) := + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls originalConstrVars.toList do + let solEqAtom ← OC.map3M mkSolEqAtom atoms canonWithSolution vcondVars + trace[CvxLean.debug] "`solEqAtom`: {solEqAtom}." + return solEqAtom + +/-- Lift `mkFeasibility` to `OC`. -/ +def mkFeasibilityOC (originalVarsDecls : Array LocalDecl) (atoms : OC GraphAtomDataTree) + (canonWithSolution : OC CanonExprsWithSolutionTree) (bconds : OC BCondsTree) + (vcondVars : OC VCondsTree) (originalConstrVars : Array LocalDecl) + (solEqAtom : OC SolEqAtomProofsTree) : MetaM (OC FeasibilityProofsTree) := + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls originalConstrVars.toList do + let feasibility ← OC.map5M mkFeasibility atoms canonWithSolution bconds vcondVars solEqAtom + trace[CvxLean.debug] "`feasibility`: {feasibility}." + return feasibility + +/-- Lift `mkCanonExprs` to `OC`. -/ +def mkCanonExprsOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) + (atoms : OC GraphAtomDataTree) (newVars : OC NewVarsTree) : MetaM (OC CanonExprsTree) := + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls newVarDecls do + let canonExprs ← OC.map2M mkCanonExprs atoms newVars + trace[CvxLean.debug] "`canonExprs`: {canonExprs}." + return canonExprs + +/-- Lift `mkNewConstrs` and `mkNewConstrVars` to `OC`. Also returns an array containing all the new +constraints as expressions, and an array containing pointers to all their local declarations. -/ +def mkNewConstrsOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) + (bconds : OC BCondsTree) (atoms : OC GraphAtomDataTree) (newVars : OC NewVarsTree) + (canonExprs : OC CanonExprsTree) : + MetaM (Array Expr × OC NewConstrVarsTree × Array LocalDecl) := + withExistingLocalDecls (originalVarsDecls.toList) do + withExistingLocalDecls newVarDecls do + let newConstrs ← OC.map4M mkNewConstrs atoms bconds newVars canonExprs + trace[CvxLean.debug] "`newConstrs` {newConstrs}." + let newConstrVars ← OC.mapM mkNewConstrVars newConstrs + trace[CvxLean.debug] "`newConstrsVars`: {newConstrs}." + let newConstrs : Array Expr := newConstrs.fold #[] fun acc cs => + cs.fold acc Array.append + let newConstrVarsArray : Array LocalDecl := newConstrVars.fold #[] fun acc cs => + cs.fold acc Array.append + return (newConstrs, newConstrVars, newConstrVarsArray) + +/-- Lift `optimalityAndVCondElim` to OC and split it into the optimality proofs and the vcondition +elimination map. -/ +def mkOptimalityAndVCondElimOC (originalVarsDecls : Array LocalDecl) (newVarDecls : List LocalDecl) + (newConstrVarsArray : Array LocalDecl) (atoms : OC GraphAtomDataTree) (args : OC ArgumentsTree) + (canonExprs : OC CanonExprsTree) (newVars : OC NewVarsTree) + (newConstrVars : OC NewConstrVarsTree) (curvature : OC CurvatureTree) (bconds : OC BCondsTree) + (vcondIdx : OC VCondsIdxsTree) : MetaM (OC OptimalityProofsTree × VCondElimMap):= + withExistingLocalDecls originalVarsDecls.toList do + withExistingLocalDecls newVarDecls do + withExistingLocalDecls newConstrVarsArray.toList do + let optimalityAndVCondElim ← OC.map7M mkOptimalityAndVCondElim atoms args canonExprs newVars + newConstrVars curvature bconds + let optimality := optimalityAndVCondElim.map (fun oce => oce.map Prod.fst Prod.fst) + let vcondElim := optimalityAndVCondElim.map (fun oce => oce.map Prod.snd Prod.snd) + trace[CvxLean.debug] "`optimality`: {optimality}." + trace[CvxLean.debug] "`vcondElim`: {vcondElim}." + trace[CvxLean.debug] "`vcondIdx`: {vcondIdx}." + let vcondElimWithIdx ← OC.map2M (fun a b => Tree.zip a b) vcondIdx vcondElim + let vcondElimMap := vcondElimWithIdx.fold {} + fun (map : VCondElimMap) ci => + ci.fold map fun (map : VCondElimMap) ci => Id.run do + let mut res := map + for i in [:ci.1.size] do + res := res.insert ci.1[i]! ci.2[i]! + return res + return (optimality, vcondElimMap) + +end OC + +/-- Combine everything into a processed atom tree, starting from the original optimization +problem. First, make the atom tree. Gather all the local declarations (original variables, original +constraints, new variables, new constraints). Get all the forward images (from the solutions of the +atoms), and the canonized expressions for every component. As expected, it also includes all the +necessary proofs, as trees. -/ +def mkProcessedAtomTree (objCurv : Curvature) (objFun : Expr) (constraints : List (Name × Expr)) + (originalVarsDecls : Array LocalDecl) (extraDecls : Array LocalDecl := #[]) + (extraVars : Array FVarId := #[]) : MetaM ProcessedAtomTree := do + let oc ← mkOC objFun constraints originalVarsDecls + let (failedAtom, failedAtomMsgs, atoms, args, curvature, bconds) ← mkAtomTree objCurv + originalVarsDecls (extraDecls := extraDecls) (extraVars := extraVars) oc + let originalConstrVars ← mkOriginalConstrVars originalVarsDecls constraints.toArray + let (vcondIdx, isVCond, vcondVars) ← mkVConditions originalVarsDecls oc constraints atoms args + failedAtom failedAtomMsgs originalConstrVars + let newVars ← withExistingLocalDecls originalVarsDecls.toList do + OC.map2MwithCounter mkNewVars atoms args + let newVarDecls ← mkNewVarDeclList newVars + let canonWithSolution ← withExistingLocalDecls originalVarsDecls.toList do + OC.mapM mkCanonWithSolution atoms + let forwardImagesNewVars ← withExistingLocalDecls originalVarsDecls.toList do + mkForwardImagesNewVars canonWithSolution + let solEqAtom ← mkSolEqAtomOC originalVarsDecls atoms canonWithSolution vcondVars + originalConstrVars + let feasibility ← mkFeasibilityOC originalVarsDecls atoms canonWithSolution bconds vcondVars + originalConstrVars solEqAtom + let canonExprs ← mkCanonExprsOC originalVarsDecls newVarDecls atoms newVars + let (newConstrs, newConstrVars, newConstrVarsArray) ← + mkNewConstrsOC originalVarsDecls newVarDecls bconds atoms newVars canonExprs + let (optimality, vcondElimMap) ← mkOptimalityAndVCondElimOC originalVarsDecls newVarDecls + newConstrVarsArray atoms args canonExprs newVars newConstrVars curvature bconds vcondIdx + + return ProcessedAtomTree.mk + (originalVarsDecls := originalVarsDecls) + (originalConstrVars := originalConstrVars) + (newVarDecls := newVarDecls) + (newConstrs := newConstrs) + (newConstrVarsArray := newConstrVarsArray) + (forwardImagesNewVars := forwardImagesNewVars) + (constraints := constraints) + (isVCond := isVCond) + (vcondElimMap := vcondElimMap) + (solEqAtom := solEqAtom) + (feasibility := feasibility) + (canonExprs := canonExprs) + (optimality := optimality) + +end DCP + +end CvxLean diff --git a/CvxLean/Tactic/DCP/DCPTypes.lean b/CvxLean/Tactic/DCP/DCPTypes.lean index e69de29b..9e4f7188 100644 --- a/CvxLean/Tactic/DCP/DCPTypes.lean +++ b/CvxLean/Tactic/DCP/DCPTypes.lean @@ -0,0 +1,207 @@ +import CvxLean.Tactic.DCP.Tree +import CvxLean.Tactic.DCP.OC +import CvxLean.Tactic.DCP.AtomExt + +/-! +# Types for DCP + +This file gives names to the main structures used by the DCP procedure for better readability. +-/ + +namespace CvxLean + +open Lean + +namespace DCP + +/- Atom data. -/ + +/-- Tree with all the atom data stored in the library. The leaves are just (affine or constant) +expressions. -/ +abbrev GraphAtomDataTree := Tree GraphAtomData Expr + + +/- Arguments. -/ + +/-- This is the type of an argument. -/ +abbrev Argument := Expr + +/-- An atom may have several arguments. -/ +abbrev Arguments := Array Argument + +/-- Tree of the arguments of every atom unfolded (expressions at the nodes). -/ +abbrev ArgumentsTree := Tree Arguments Unit + + +/- Curvature. -/ + +/-- Curvatures in the atom tree. -/ +abbrev CurvatureTree := Tree Curvature Curvature + + +/- Background conditions. -/ + +/-- A background condition is just an expression of type `Prop`. -/ +abbrev BCond := Expr + +/-- An atom may have several background conditions. -/ +abbrev BConds := Array BCond + +/-- Background conditions of every atom in the atom tree. -/ +abbrev BCondsTree := Tree BConds BConds + + +/- Result of finding an atom. -/ + +/-- Four trees: atom data tree, arguments tree, curvature tree, and background condition tree. -/ +abbrev AtomDataTrees := GraphAtomDataTree × ArgumentsTree × CurvatureTree × BCondsTree + + +/- New variables and constraints. -/ + +/-- New variables introduced by every atom, as local declarations at the nodes. -/ +abbrev NewVarsTree := Tree (Array LocalDecl) Unit + +/-- Variables for the new constraints introduced by every atom, as local declarations at the +nodes. -/ +abbrev NewConstrVarsTree := Tree (Array LocalDecl) Unit + + +/- New constraints (expressions). -/ + +/-- All the new constraints introduced by each atom, as expressions at the node. -/ +abbrev NewConstrsTree := Tree (Array Expr) Unit + + +/- Variable condtions. -/ + +/-- In general a variable condition can either match a constraint exactly or it can be inferred from +the other constraints, we represent that as a sum type. If a number is given, then it is interpreted +as the index of the constraint that corresponds to the vcondition. -/ +abbrev PreVCond := Nat ⊕ Expr + +/- An atom may have several variable conditions. -/ +abbrev PreVConds := Array PreVCond + +/- All the vconditions for all the atoms in the atom tree. -/ +abbrev PreVCondsTree := Tree PreVConds Unit + +/-- This is different from `PreVCond` and instead holds proofs of variable conditions. Still, it +could either be a variable corresponding to a constraint, or an expression involving several +constraints, in case it is inferred. -/ +abbrev VCond := Expr + +/-- Several proofs of vconditions. -/ +abbrev VConds := Array VCond + +/-- All the vcondition proofs for the atom. -/ +abbrev VCondsTree := Tree VConds Unit + +/-- It is useful to only consider the vconditions that need to be eliminated because they match a +constraint exactly. This represents the index of such constraint. -/ +abbrev VCondIdx := Nat + +/-- Several `VCondIdx`s. -/ +abbrev VCondsIdxs := Array VCondIdx + +/-- All the indices of all the vconditions marked for elimination. -/ +abbrev VCondsIdxsTree := Tree VCondsIdxs Unit + + +/- Solutions. -/ + +/-- A solution is just an expression (involving the atom arguments). -/ +abbrev NewVarAssignment := Expr + +/-- If an atom has several implementation variables, we need a solution for each of them. -/ +abbrev Solution := Array NewVarAssignment + + +/- Canonized expressions. -/ + +/-- A canonized expression. -/ +abbrev CanonExpr := Expr + +/-- A tree of canonized expressions. -/ +abbrev CanonExprsTree := Tree CanonExpr CanonExpr + +/-- Canonized expression and solution put together.-/ +structure CanonExprWithSolution where + canonExpr : CanonExpr + solution : Solution +deriving Inhabited + +/-- Tree of canonized expression and solution for each atom. -/ +abbrev CanonExprsWithSolutionTree := Tree CanonExprWithSolution CanonExprWithSolution + + +/- Proofs of solution equals atom expression. -/ + +/-- A proof of solution-equals-atom for an atom.-/ +abbrev SolEqAtomProof := Expr + +/-- Tree of all the solution-equals-atom proofs in the atom tree. -/ +abbrev SolEqAtomProofsTree := Tree SolEqAtomProof SolEqAtomProof + + +/- Proofs of feasibility. -/ + +/-- A feasibility proof. -/ +abbrev FeasibilityProof := Expr + +/-- An atom may have several feasibility proofs, one per implementation constraint. -/ +abbrev FeasibilityProofs := Array FeasibilityProof + +/-- Tree of all the feasibility proofs in the atom tree. -/ +abbrev FeasibilityProofsTree := Tree FeasibilityProofs Unit + + +/- Proofs of optimality and variable condition elimination. -/ + +/-- Optimality proof. There is only one per atom. -/ +abbrev OptimalityProof := Expr + +/-- Tree of all the optimality proofs in the atom tree. -/ +abbrev OptimalityProofsTree := Tree OptimalityProof OptimalityProof + +/-- A proof of vcondtion elimination. -/ +abbrev VCondElimProof := Expr + +/-- An atom may have several vconditions. Each needs a vcondition elimination proof. -/ +abbrev VCondElimProofs := Array VCondElimProof + +/-- Optimality and vcondition elimination proofs together. -/ +abbrev OptimalityAndVCondElimProofs := OptimalityProof × VCondElimProofs + +/-- Tree of all the optimality and vcondition elimination proofs in the atom tree. -/ +abbrev OptimalityAndVCondElimProofsTree := + Tree OptimalityAndVCondElimProofs OptimalityAndVCondElimProofs + +/-- Vcondition elimination map, from constraint index to the constraint expression. -/ +abbrev VCondElimMap := Std.HashMap Nat Expr + + +/- Processed atom tree. -/ + +/-- This structure holds all the necessary information to build a canonized problem and a proof of +equivalence. In `CvxLean/Tactic/DCP/DCPMakers.lean`, we define the procedure that takes an +optimization problem, builds an atom tree and "processes" it to make a processed atom tree. -/ +structure ProcessedAtomTree where + originalVarsDecls : Array LocalDecl + originalConstrVars : Array LocalDecl + newVarDecls : List LocalDecl + newConstrs : Array Expr + newConstrVarsArray : Array LocalDecl + forwardImagesNewVars : Array Expr + constraints : List (Name × Expr) + isVCond : Array Bool + vcondElimMap : VCondElimMap + solEqAtom : OC SolEqAtomProofsTree + feasibility : OC FeasibilityProofsTree + canonExprs : OC CanonExprsTree + optimality : OC OptimalityProofsTree +deriving Inhabited + +end DCP + +end CvxLean diff --git a/CvxLean/Tactic/DCP/DiscrTree.lean b/CvxLean/Tactic/DCP/DiscrTree.lean index 3ca7a186..90296d72 100644 --- a/CvxLean/Tactic/DCP/DiscrTree.lean +++ b/CvxLean/Tactic/DCP/DiscrTree.lean @@ -1,6 +1,9 @@ import Lean --- This is a copy of Lean's DiscrTree, but with all definitional equality checks removed. +/-! +This is a copy of Lean's DiscrTree, but with all definitional equality checks removed. Used to store +and match atoms, also in real-to-float. +-/ namespace CvxLean @@ -83,8 +86,9 @@ def empty : DiscrTree α := { root := {} } partial def Trie.format [ToMessageData α] : Trie α → MessageData | Trie.node vs cs => MessageData.group $ MessageData.paren $ - "node" ++ (if vs.isEmpty then MessageData.nil else " " ++ toMessageData vs) - ++ MessageData.joinSep (cs.toList.map $ fun ⟨k, c⟩ => MessageData.paren (toMessageData k ++ " => " ++ format c)) "," + "node" ++ (if vs.isEmpty then MessageData.nil else " " ++ toMessageData vs) ++ + MessageData.joinSep (cs.toList.map $ fun ⟨k, c⟩ => + MessageData.paren (toMessageData k ++ " => " ++ format c)) "," instance [ToMessageData α] : ToMessageData (Trie α) := ⟨Trie.format⟩ @@ -143,7 +147,8 @@ private def ignoreArg (a : Expr) (i : Nat) (infos : Array Meta.ParamInfo) : Meta else isProof a -private partial def pushArgsAux (infos : Array Meta.ParamInfo) : Nat → Expr → Array Expr → MetaM (Array Expr) +private partial def pushArgsAux (infos : Array Meta.ParamInfo) : + Nat → Expr → Array Expr → MetaM (Array Expr) | i, Expr.app f a, todo => do if (← ignoreArg a i infos) then pushArgsAux infos (i-1) f (todo.push tmpStar) @@ -157,7 +162,7 @@ def mkNoindexAnnotation (e : Expr) : Expr := def hasNoindexAnnotation (e : Expr) : Bool := annotation? `noindex e |>.isSome -private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) := do +private def pushArgs (_root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) := do if hasNoindexAnnotation e then return (Key.star, todo) else @@ -246,7 +251,7 @@ def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α let keys ← mkPath e return d.insertCore keys v -private def getKeyArgs (e : Expr) (isMatch root : Bool) : MetaM (Key × Array Expr) := do +private def getKeyArgs (e : Expr) (isMatch _root : Bool) : MetaM (Key × Array Expr) := do match e.getAppFn with | Expr.lit v => return (Key.lit v, #[]) | Expr.const c _ => @@ -255,7 +260,7 @@ private def getKeyArgs (e : Expr) (isMatch root : Bool) : MetaM (Key × Array Ex | Expr.fvar fvarId => let nargs := e.getAppNumArgs return (Key.fvar fvarId nargs, e.getAppRevArgs) - | Expr.mvar mvarId => + | Expr.mvar _mvarId => if isMatch then return (Key.other, #[]) else do @@ -271,10 +276,10 @@ private def getKeyArgs (e : Expr) (isMatch root : Bool) : MetaM (Key × Array Ex return (Key.other, #[]) private abbrev getMatchKeyArgs (e : Expr) (root : Bool) : MetaM (Key × Array Expr) := - getKeyArgs e (isMatch := true) (root := root) + getKeyArgs e (isMatch := true) (_root := root) private abbrev getUnifyKeyArgs (e : Expr) (root : Bool) : MetaM (Key × Array Expr) := - getKeyArgs e (isMatch := false) (root := root) + getKeyArgs e (isMatch := false) (_root := root) private def getStarResult (d : DiscrTree α) : Array α := let result : Array α := Array.mkEmpty initCapacity @@ -285,7 +290,8 @@ private def getStarResult (d : DiscrTree α) : Array α := private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) := cs.binSearch (k, default) (fun a b => a.1 < b.1) -private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do +private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : + MetaM (Array α) := do match c with | Trie.node vs cs => if todo.isEmpty then @@ -313,13 +319,15 @@ private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Arr match k with | Key.star => pure result /- - Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`. - A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child. + Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are + `(Key.arrow, #[a, b])`. A non-dependent arrow may be an instance of a dependent arrow + (stored at `DiscrTree`). Thus, we also visit the `Key.other` child. -/ | Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result) | _ => visitNonStar k args result -private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := +private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : + MetaM (Array α) := match d.root.find? k with | none => return result | some c => getMatchLoop args c result @@ -348,7 +356,7 @@ partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) := do where process (skip : Nat) (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do match skip, c with - | skip+1, Trie.node vs cs => + | skip + 1, Trie.node _vs cs => if cs.isEmpty then return result else @@ -417,7 +425,7 @@ def foldM [Monad m] (f : σ → Array Key → α → m σ) (init : σ) (t : Disc def fold (f : σ → Array Key → α → σ) (init : σ) (t : DiscrTree α) : σ := Id.run $ t.foldM (init := init) λ s keys a => return f s keys a --- TODO inefficient since it doesn't take advantage of the Trie structure at all +-- TODO: inefficient since it doesn't take advantage of the Trie structure at all @[inline] def merge [BEq α] (t u : DiscrTree α) : DiscrTree α := if t.root.size < u.root.size then loop t u else loop u t diff --git a/CvxLean/Tactic/DCP/OC.lean b/CvxLean/Tactic/DCP/OC.lean index aaca13e5..4f1797dd 100644 --- a/CvxLean/Tactic/DCP/OC.lean +++ b/CvxLean/Tactic/DCP/OC.lean @@ -1,105 +1,136 @@ import Lean +/-! +This file defines the `OC` type, which is a generic type to split data into objective function data +and constraints data. It is used everywhere in DCP (also in pre-DCP). +-/ + open Lean Lean.Meta namespace CvxLean -structure OC (α : Type) := - (objFun : α) - (constr : Array α) +/-- Structure to split data into the objective funtion (`objFun`) and the constraints (`constrs`). +-/ +structure OC (α : Type) where + objFun : α + constrs : Array α deriving Inhabited namespace OC variable {α β γ δ ε ζ η θ} +variable [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited ε] [Inhabited ζ] +variable [Inhabited η] + +/-! Many maps operating on up to 7 `OC`s. -/ + def map (f : α → β) (oc : OC α) : OC β := - OC.mk (f oc.objFun) (oc.constr.map f) + OC.mk (f oc.objFun) (oc.constrs.map f) def mapM (f : α → MetaM β) (oc : OC α) : MetaM (OC β) := do - return OC.mk (← f oc.objFun) (← oc.constr.mapM f) + return OC.mk (← f oc.objFun) (← oc.constrs.mapM f) -def map2M [Inhabited α] [Inhabited β] (f : α → β → MetaM γ) (a : OC α) (b : OC β) : MetaM (OC γ) := do +def map2M (f : α → β → MetaM γ) (a : OC α) (b : OC β) : MetaM (OC γ) := do let objFun ← f a.objFun b.objFun - if a.constr.size != b.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← f a.constr[i]! b.constr[i]! - return OC.mk objFun constr + if a.constrs.size != b.constrs.size then + throwError "`OC` error: could not map (2), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| ← f a.constrs[i]! b.constrs[i]! + return OC.mk objFun constrs -def map3M [Inhabited α] [Inhabited β] [Inhabited γ] (f : α → β → γ → MetaM δ) (a : OC α) (b : OC β) (c : OC γ) : MetaM (OC δ) := do +def map3M (f : α → β → γ → MetaM δ) (a : OC α) (b : OC β) (c : OC γ) : MetaM (OC δ) := do let objFun ← f a.objFun b.objFun c.objFun - if a.constr.size != b.constr.size ∨ b.constr.size != c.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← f a.constr[i]! b.constr[i]! c.constr[i]! - return OC.mk objFun constr + if a.constrs.size != b.constrs.size || b.constrs.size != c.constrs.size then + throwError "`OC` error: could not map (3), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| ← f a.constrs[i]! b.constrs[i]! c.constrs[i]! + return OC.mk objFun constrs -def map4M [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] (f : α → β → γ → δ → MetaM ε) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) : MetaM (OC ε) := do +def map4M (f : α → β → γ → δ → MetaM ε) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) : + MetaM (OC ε) := do let objFun ← f a.objFun b.objFun c.objFun d.objFun - if a.constr.size != b.constr.size ∨ b.constr.size != c.constr.size ∨ c.constr.size != d.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← f a.constr[i]! b.constr[i]! c.constr[i]! d.constr[i]! - return OC.mk objFun constr + if a.constrs.size != b.constrs.size || b.constrs.size != c.constrs.size + || c.constrs.size != d.constrs.size then + throwError "`OC` error: could not map (4), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| ← f a.constrs[i]! b.constrs[i]! c.constrs[i]! d.constrs[i]! + return OC.mk objFun constrs -def map5M [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited ε] (f : α → β → γ → δ → ε → MetaM ζ) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) (e : OC ε) : MetaM (OC ζ) := do +def map5M (f : α → β → γ → δ → ε → MetaM ζ) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) (e : OC ε) : + MetaM (OC ζ) := do let objFun ← f a.objFun b.objFun c.objFun d.objFun e.objFun - if a.constr.size != b.constr.size ∨ b.constr.size != c.constr.size ∨ c.constr.size != d.constr.size ∨ d.constr.size != e.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← f a.constr[i]! b.constr[i]! c.constr[i]! d.constr[i]! e.constr[i]! - return OC.mk objFun constr + if a.constrs.size != b.constrs.size || b.constrs.size != c.constrs.size + || c.constrs.size != d.constrs.size || d.constrs.size != e.constrs.size then + throwError "`OC` error: could not map (5), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| + ← f a.constrs[i]! b.constrs[i]! c.constrs[i]! d.constrs[i]! e.constrs[i]! + return OC.mk objFun constrs -def map6M [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited ε] [Inhabited ζ] (h : α → β → γ → δ → ε → ζ → MetaM η) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) (e : OC ε) (f : OC ζ) : MetaM (OC η) := do +def map6M (h : α → β → γ → δ → ε → ζ → MetaM η) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) + (e : OC ε) (f : OC ζ) : MetaM (OC η) := do let objFun ← h a.objFun b.objFun c.objFun d.objFun e.objFun f.objFun - if a.constr.size != b.constr.size ∨ b.constr.size != c.constr.size ∨ c.constr.size != d.constr.size ∨ d.constr.size != e.constr.size ∨ e.constr.size != f.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← h a.constr[i]! b.constr[i]! c.constr[i]! d.constr[i]! e.constr[i]! f.constr[i]! - return OC.mk objFun constr - -def map7M [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited ε] [Inhabited ζ] [Inhabited η] - (h : α → β → γ → δ → ε → ζ → η → MetaM θ) - (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) (e : OC ε) (f : OC ζ) (g : OC η) : MetaM (OC θ) := do + if a.constrs.size != b.constrs.size || b.constrs.size != c.constrs.size + || c.constrs.size != d.constrs.size || d.constrs.size != e.constrs.size + || e.constrs.size != f.constrs.size then + throwError "`OC` error: could not map (6), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| + ← h a.constrs[i]! b.constrs[i]! c.constrs[i]! d.constrs[i]! e.constrs[i]! f.constrs[i]! + return OC.mk objFun constrs + +def map7M (h : α → β → γ → δ → ε → ζ → η → MetaM θ) (a : OC α) (b : OC β) (c : OC γ) (d : OC δ) + (e : OC ε) (f : OC ζ) (g : OC η) : MetaM (OC θ) := do let objFun ← h a.objFun b.objFun c.objFun d.objFun e.objFun f.objFun g.objFun - if a.constr.size != b.constr.size ∨ b.constr.size != c.constr.size ∨ c.constr.size != d.constr.size ∨ d.constr.size != e.constr.size ∨ e.constr.size != f.constr.size ∨ f.constr.size != g.constr.size then throwError "different number of constrs" - let mut constr := #[] - for i in [:a.constr.size] do - constr := constr.push $ ← h a.constr[i]! b.constr[i]! c.constr[i]! d.constr[i]! e.constr[i]! f.constr[i]! g.constr[i]! - return OC.mk objFun constr + if a.constrs.size != b.constrs.size || b.constrs.size != c.constrs.size + || c.constrs.size != d.constrs.size || d.constrs.size != e.constrs.size + || e.constrs.size != f.constrs.size || f.constrs.size != g.constrs.size then + throwError "`OC` error: could not map (7), different number of constraints." + let mut constrs := #[] + for i in [:a.constrs.size] do + constrs := constrs.push <| + ← h a.constrs[i]! b.constrs[i]! c.constrs[i]! d.constrs[i]! e.constrs[i]! f.constrs[i]! + g.constrs[i]! + return OC.mk objFun constrs def mapMwithCounter (f : Nat → α → MetaM (β × Nat)) (oc : OC α) : MetaM (OC β) := do let mut n := 0 let (objFun, n') ← f n oc.objFun n := n' let mut constrs := #[] - for constr in oc.constr do + for constr in oc.constrs do let (c, n') ← f n constr constrs := constrs.push c n := n' return OC.mk objFun constrs -def map2MwithCounter [Inhabited α] [Inhabited β] (f : Nat → α → β → MetaM (γ × Nat)) (a : OC α) (b : OC β) : MetaM (OC γ) := do - if a.constr.size != b.constr.size then throwError "different number of constrs" +def map2MwithCounter (f : Nat → α → β → MetaM (γ × Nat)) (a : OC α) (b : OC β) : MetaM (OC γ) := do + if a.constrs.size != b.constrs.size then + throwError "`OC` error: could not map (2C), different number of constraints." let mut n := 0 let (objFun, n') ← f n a.objFun b.objFun n := n' let mut constrs := #[] - for i in [:a.constr.size] do - let (c, n') ← f n a.constr[i]! b.constr[i]! + for i in [:a.constrs.size] do + let (c, n') ← f n a.constrs[i]! b.constrs[i]! constrs := constrs.push c n := n' return OC.mk objFun constrs def fold (oc : OC α) (init : β) (f : β → α → β) : β := Id.run do let b ← f init oc.objFun - let b ← oc.constr.foldl f b + let b ← oc.constrs.foldl f b return b -instance [ToMessageData α] : ToMessageData (OC α) := { - toMessageData := fun oc => "ObjFun:" ++ toMessageData oc.objFun ++ "\nConstr:" ++ toMessageData oc.constr -} +instance [ToMessageData α] : ToMessageData (OC α) where + toMessageData := fun oc => + "objFun:" ++ toMessageData oc.objFun ++ "\nconstrs:" ++ toMessageData oc.constrs end OC diff --git a/CvxLean/Tactic/DCP/Tree.lean b/CvxLean/Tactic/DCP/Tree.lean index 171ab64e..48c3780f 100644 --- a/CvxLean/Tactic/DCP/Tree.lean +++ b/CvxLean/Tactic/DCP/Tree.lean @@ -1,8 +1,13 @@ import Lean import CvxLean.Lib.Math.Data.Array +/-! +Simple trees that can have values of different types at the leaves and nodes. +-/ + namespace CvxLean +/-- Tree with node labels of type `α` and leaf labels of type `β`. -/ inductive Tree (α β : Type) : Type | node (val : α) (children : Array (Tree α β)) : Tree α β | leaf (val : β) : Tree α β @@ -14,7 +19,7 @@ variable {α β γ δ} open Lean -partial def toMessageData [ToMessageData α] [ToMessageData β] (x : Tree α β) : MessageData := match x with +partial def toMessageData [ToMessageData α] [ToMessageData β] : Tree α β → MessageData | Tree.node val children => let children := children.map toMessageData MessageData.paren <| @@ -25,30 +30,35 @@ partial def toMessageData [ToMessageData α] [ToMessageData β] (x : Tree α β) instance [ToMessageData α] [ToMessageData β] : ToMessageData (Tree α β) where toMessageData := toMessageData - -partial def zip [Inhabited α] [Inhabited β] [ToMessageData α] [ToMessageData β] [ToMessageData γ] [ToMessageData δ] : - Tree α γ → Tree β δ → MetaM (Tree (α × β) (γ × δ)) - | node val₁ children₁, node val₂ children₂ => do return node (val₁, val₂) (← Array.zipWithM zip children₁ children₂) - | leaf val₁, leaf val₂ => pure $ leaf (val₁, val₂) +/-- Zip two trees. -/ +partial def zip [Inhabited α] [Inhabited β] [ToMessageData α] [ToMessageData β] [ToMessageData γ] + [ToMessageData δ] : Tree α γ → Tree β δ → MetaM (Tree (α × β) (γ × δ)) + | node val₁ children₁, node val₂ children₂ => do + return node (val₁, val₂) (← Array.zipWithM zip children₁ children₂) + | leaf val₁, leaf val₂ => pure <| leaf (val₁, val₂) | t1, t2 => throwError "zipped trees do not match {t1} {t2}" +/-- Map two functions on the values at the nodes and the values at the leafs, respectively. -/ partial def map [Inhabited γ] (f : α → γ) (g : β → δ) : Tree α β → Tree γ δ -| node val children => node (f val) (children.map (map f g)) -| leaf val => leaf $ g val + | node val children => node (f val) (children.map (map f g)) + | leaf val => leaf <| g val +/-- Fold a function on the nodes of the tree. -/ partial def fold [Inhabited γ] (init : γ) (f : γ → α → γ) : Tree α β → γ | node val children => Id.run do - let mut res := init - for child in children do - res ← fold res f child - res := f res val - return res + let mut res := init + for child in children do + res ← fold res f child + res := f res val + return res | leaf _ => init +/-- For trees with the same type of nodes and leaves, retrieve the value. -/ def val : Tree α α → α | node val _ => val | leaf val => val +/-- Size of a tree. -/ partial def size : Tree α β → Nat | node _ children => children.foldl (fun acc child => acc + size child) 0 | leaf _ => 1 diff --git a/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean b/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean index 3607c4d6..4275a8ea 100644 --- a/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean +++ b/CvxLean/Tactic/PreDCP/Egg/EggTypes.lean @@ -1,19 +1,89 @@ import Lean import CvxLean.Tactic.DCP.Tree +import CvxLean.Tactic.DCP.OC + +/-! +The main types user around the `egg` interface. +-/ namespace CvxLean -def EggTree := CvxLean.Tree String String +section Minimization + +/-- An intermediate structure to represent problems as trees of strings. -/ +def EggTree := Tree String String + +instance : Inhabited EggTree where + default := Tree.leaf "" + +/-- A more refined `EggTree`, split by its objective function and components. -/ +def EggOCTree := OC (String × EggTree) +/-- Helper structure to keep track of the arguments, some of which may be pre-determined, when +translating from atom trees to trees with constructors in the `egg` language. -/ +inductive EggTreeOpArgTag + | arg + | val (s : String) + +/-- Once `EggTree` is flattened into S-expressions, it can be turned into an `EggMinimization`. This +is what will by JSONified and sent to `egg`. It is split into objective functions and constraints, +with their tag. Note that `egg` builds one problem term from this. However, it is useful to keep +track of the names of the components so that when `egg` comes back with a sequence of rewrites it +can easily indicate the location of the rewrite. -/ structure EggMinimization where objFun : String constrs : List (String × String) -/-- Corresponds to the domain structure. -/ +/-- This encodes an interval, used for e-class analysis. For example `⟨"0", "1", "-inf", "8"⟩` +is how the interval `(-∞, 8]` would be encoded. -/ structure EggDomain where hi : String lo : String hi_open : String lo_open : String +/-- Every `EggDomain` is infered from a constraint named `constraintName`, and it corresponds to +a variable or parameter named `varOrParamName`. This structure keeps track of that. -/ +structure EggDomainIdentified where + constraintName : String + varOrParamName : String + domain : EggDomain + +/-- An `EggOCTree` with also the domain information. -/ +def EggOCTreeExtended := EggOCTree × Array EggDomainIdentified + +/-- A request consists of an `EggMinimization` and a list of domains per variable (or parameter). -/ +structure EggRequest where + domains : List (String × EggDomain) + target : EggMinimization + +end Minimization + +section Rewrites + +/-- Recall that, in `egg`, rewrites can be forward or backward (in the response). -/ +inductive EggRewriteDirection where + | Forward + | Backward +deriving Inhabited, DecidableEq + +/-- A rewrite returned by `egg`. -/ +structure EggRewrite where + /-- This is the name of the lemma used for rewriting. -/ + rewriteName : String + /-- Forward or backward. -/ + direction : EggRewriteDirection + /-- The name of the constraint or `"objFun"`. -/ + location : String + /-- Source sub-expression. -/ + subexprFrom : String + /-- Target sub-expression. -/ + subexprTo : String + /-- The context expression with a `"?"` in place of the sub-expression. For example, if we are + rewriting `x + (y + z)` into `x + (z + y)`, the expected term would be `x + ?`, with `subexprFrom` + set to `y + z` and `subexprTo` set to `z + y`. -/ + expectedTerm : String + +end Rewrites + end CvxLean diff --git a/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean b/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean index b76431d8..4818b2ed 100644 --- a/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean +++ b/CvxLean/Tactic/PreDCP/Egg/FromMinimization.lean @@ -1,22 +1,31 @@ import CvxLean.Tactic.DCP.Tree -import CvxLean.Tactic.PreDCP.UncheckedDCP +import CvxLean.Tactic.PreDCP.Egg.UncheckedDCP import CvxLean.Tactic.PreDCP.Egg.EggTypes -namespace CvxLean +/-! +# Minimization problem to Egg minimization (S-expressions) + +Using `CvxLean/Tactic/PreDCP/Egg/UncheckedDCP.lean`, we build an atom tree from a given problem, +ignoring the curvature checks. This file does all the processing needed after that point to turn +it into a problem that can be sent to `egg`. This includes: +* Adjusting the name of the constructors. +* Building the domain for each variable in normalized linear constraint to initialize the e-class + analyses. +* Flattening the trees to strings. +-/ -section Egg.FromMinimization +namespace CvxLean -open Lean +open Lean Meta -inductive EggTree.OpArgTag - | arg - | val (s : String) +namespace EggTree -/-- Mapping between atom names that come from the unchecked tree construction -and their egg counterparts. `prob`, `objFun`, `constr` and `constrs` are special -cases. The rest of names come from the atom library. It also returns the arity -of the operation and some extra arguments. -/ -def EggTree.opMap : HashMap String (String × Array EggTree.OpArgTag) := +/-- Mapping between atom names that come from the unchecked tree construction and their `egg` +counterparts. `prob`, `objFun`, `constr` and `constrs` are special cases. The rest of names come +from the atom library. It also returns the extra arguments, which sometimes may be pre-determined, +for example, the atom `powNegOne` corresponds to the constructor `pow` in the `egg` language with +the second argumetn set to `-1`. -/ +def opMap : HashMap String (String × Array EggTreeOpArgTag) := HashMap.ofList [ ("prob", ("prob", #[.arg, .arg])), ("objFun", ("objFun", #[.arg])), @@ -55,12 +64,11 @@ def EggTree.opMap : HashMap String (String × Array EggTree.OpArgTag) := ("norm2₂", ("norm2", #[.arg, .arg])) ] -/-- Traverse the tree and use `EggTree.opMap` to align the names of the -constructors. -/ -partial def EggTree.adjustOps (t : Tree String String) : MetaM (Tree String String) := do +/-- Traverse the tree and use `opMap` to align the names of the constructors. -/ +partial def adjustOps (t : EggTree) : MetaM EggTree := do match t with | Tree.node op children => - if let some (newOp, argTags) := EggTree.opMap.find? op then + if let some (newOp, argTags) := opMap.find? op then let mut children ← children.mapM adjustOps let argsGiven := children.size let mut newChildren := #[] @@ -86,17 +94,12 @@ partial def EggTree.adjustOps (t : Tree String String) : MetaM (Tree String Stri | Tree.leaf "unknown" => throwError "Unknown atom." | l => return l -def _root_.Option.map2 {α β γ} (f : α → β → γ) : Option α → Option β → Option γ - | some a, some b => some <| f a b - | _, _ => none - -/-- Extract the numerical value from the egg tree, if any. This is used to -construct the domain information needed in the e-class analyses. -NOTE(RFM): This is a source of inaccuracy as some operations are approximate, -however, it does not compromise the soundness of the procedure. If the domains -sent to egg are incorrect, which may lead to an incorrect sequence of steps to, -transform the problem, Lean will reject the proof. -/ -partial def EggTree.getNumericalValue? : EggTree → Option Float +/-- Extract the numerical value from the egg tree, if any. This is used to construct the domain +information needed in the e-class analyses. +NOTE: This is a source of inaccuracy as some operations are approximate, however, it does not +compromise the soundness of the procedure. If the domains sent to egg are incorrect, which may lead +to an incorrect sequence of steps to, transform the problem, Lean will reject the proof. -/ +partial def getNumericalValue? : EggTree → Option Float | Tree.leaf s => Option.map Float.ofInt s.toInt? | Tree.node "neg" #[(t : EggTree)] => @@ -108,26 +111,52 @@ partial def EggTree.getNumericalValue? : EggTree → Option Float | Tree.node "exp" #[(t : EggTree)] => Option.map Float.exp t.getNumericalValue? | Tree.node "add" #[(t1 : EggTree), (t2 : EggTree)] => - Option.map2 Float.add t1.getNumericalValue? t2.getNumericalValue? + Option.map₂ Float.add t1.getNumericalValue? t2.getNumericalValue? | Tree.node "sub" #[(t1 : EggTree), (t2 : EggTree)] => - Option.map2 Float.sub t1.getNumericalValue? t2.getNumericalValue? + Option.map₂ Float.sub t1.getNumericalValue? t2.getNumericalValue? | Tree.node "mul" #[(t1 : EggTree), (t2 : EggTree)] => - Option.map2 Float.mul t1.getNumericalValue? t2.getNumericalValue? + Option.map₂ Float.mul t1.getNumericalValue? t2.getNumericalValue? | Tree.node "div" #[(t1 : EggTree), (t2 : EggTree)] => - Option.map2 Float.div t1.getNumericalValue? t2.getNumericalValue? + Option.map₂ Float.div t1.getNumericalValue? t2.getNumericalValue? | Tree.node "pow" #[(t1 : EggTree), (t2 : EggTree)] => - Option.map2 Float.pow t1.getNumericalValue? t2.getNumericalValue? + Option.map₂ Float.pow t1.getNumericalValue? t2.getNumericalValue? | _ => none -/-- -/ -def EggTree.getIdentifier? (vars : List String) : EggTree → Option String +/-- If the tree corresponds to a variable or parameter, return its name. -/ +def getIdentifier? (vars : List String) : EggTree → Option String | Tree.node "var" #[Tree.leaf s] => if s ∈ vars then some s else none | Tree.node "param" #[Tree.leaf s] => if s ∈ vars then some s else none | _ => none +/-- Flatten an `EggTree` to a string to send to egg. -/ +partial def toEggString : EggTree → String + | Tree.node n children => + let childrenStr := (children.map EggTree.toEggString).data + "(" ++ n ++ " " ++ (" ".intercalate childrenStr) ++ ")" + | Tree.leaf n => n + +/-- Size of the expresison's AST that an `EggTree` encodes.. -/ +partial def size : EggTree → Nat + | Tree.node "var" _ => 1 + | Tree.node _ children => 1 + (children.map size).foldl Nat.add 0 + | Tree.leaf _ => 1 + +/-- Compose the components of an `EggOCTree` into a single `EggTree`. -/ +def ofEggOCTree (ocTree : EggOCTree) : EggTree := + let objFun := ocTree.objFun.2 + let constrs := ocTree.constrs.map + fun (h, c) => Tree.node "constr" #[Tree.leaf h, c] + let objFunNode := Tree.node "objFun" #[objFun] + let constrNode := Tree.node "constrs" constrs + Tree.node "prob" #[objFunNode, constrNode] + +end EggTree + +namespace EggOCTreeExtended + /-- Helper function for `fromComponents` that processes a single domain-defining expression. -/ -def ExtendedEggTree.processDomainExprTree (h : String) (c : Tree String String) - (vars : List String) : Option (String × String × EggDomain) := +def processDomainExprTree (h : String) (c : Tree String String) (vars : List String) : + Option EggDomainIdentified := -- Detect domain constraints. We capture the following cases: -- * x ≤ f and f ≤ x, -- * x < f and f < x, @@ -137,38 +166,37 @@ def ExtendedEggTree.processDomainExprTree (h : String) (c : Tree String String) | Tree.node "le" #[(tl : EggTree), (tr : EggTree)] => match (tl.getIdentifier? vars, tr.getNumericalValue?) with -- v ∈ (-inf, f] - | (some v, some f) => some (h, v, ⟨"-inf", f.toString, "1", "0"⟩) + | (some v, some f) => some ⟨h, v, ⟨"-inf", f.toString, "1", "0"⟩⟩ | _ => match (tl.getNumericalValue?, tr.getIdentifier? vars) with -- v ∈ [f, inf) - | (some f, some v) => some (h, v, ⟨f.toString, "inf", "0", "1"⟩) + | (some f, some v) => some ⟨h, v, ⟨f.toString, "inf", "0", "1"⟩⟩ | _ => none | Tree.node "lt" #[(tl : EggTree), (tr : EggTree)] => match (tl.getIdentifier? vars, tr.getNumericalValue?) with -- v ∈ (-inf, f) - | (some v, some f) => some (h, v, ⟨"-inf", f.toString, "1", "1"⟩) + | (some v, some f) => some ⟨h, v, ⟨"-inf", f.toString, "1", "1"⟩⟩ | _ => match (tl.getNumericalValue?, tr.getIdentifier? vars) with -- v ∈ (f, inf) - | (some f, some v) => some (h, v, ⟨f.toString, "inf", "1", "1"⟩) + | (some f, some v) => some ⟨h, v, ⟨f.toString, "inf", "1", "1"⟩⟩ | _ => none | Tree.node "eq" #[(tl : EggTree), (tr : EggTree)] => match (tl.getIdentifier? vars, tr.getNumericalValue?) with -- v ∈ [f, f] - | (some v, some f) => some (h, v, ⟨f.toString, f.toString, "0", "0"⟩) + | (some v, some f) => some ⟨h, v, ⟨f.toString, f.toString, "0", "0"⟩⟩ | _ => match (tl.getNumericalValue?, tr.getIdentifier? vars) with -- v ∈ [f, f] - | (some f, some v) => some (h, v, ⟨f.toString, f.toString, "0", "0"⟩) + | (some f, some v) => some ⟨h, v, ⟨f.toString, f.toString, "0", "0"⟩⟩ | _ => none | _ => none -/-- Given an expression representing a minimization problem, turn it into a -tree of strings, also extract all the domain information for single varibales, -in particular positivity and nonnegativity constraints. -/ -def ExtendedEggTree.fromComponents (objFun : String × Tree String String) - (constr : Array (String × Tree String String)) (vars : List String) : - MetaM (OC (String × EggTree) × Array (String × String × EggDomain)) := do +/-- Given an expression representing a minimization problem, turn it into a tree of strings, also +extract all the domain information from simple constraints. Adjust operations to match the names +of the constructors in the `egg` language. -/ +def fromComponents (objFun : String × EggTree) (constr : Array (String × EggTree)) + (vars : List String) : MetaM EggOCTreeExtended := do let mut domainConstrs := #[] let mut i := 0 for (h, c) in constr do @@ -180,47 +208,29 @@ def ExtendedEggTree.fromComponents (objFun : String × Tree String String) i := i + 1 -- Surround variables and adjust operations in the whole tree. - let ocTree : OC (String × Tree String String) := { + let ocTree : OC (String × EggTree) := { objFun := (objFun.1, ← EggTree.adjustOps objFun.2) - constr := ← constr.mapIdxM <| - -- TODO: some constraints may have the same name, so we add the index. + constrs := ← constr.mapIdxM <| + -- NOTE: some constraints may have the same name, so we add the index. fun i (h, c) => return (s!"{i}:" ++ h, ← EggTree.adjustOps c) } return (ocTree, domainConstrs) -/-- Apply `ExtendedEggTree.fromComponents` to a minimization expression, processing it using the -unchecked DCP procedure. -/ -def ExtendedEggTree.fromMinimization (e : Meta.MinimizationExpr) (vars : List String) : - MetaM (OC (String × EggTree) × Array (String × String × EggDomain)) := do +/-- Get the atom tree form a minimization problem. Then call `fromComponents`, which takes care of +adjusting all nodes and extracting the domain information from the constraints. -/ +def fromMinimization (e : MinimizationExpr) (vars : List String) : + MetaM EggOCTreeExtended := do let ocTree ← UncheckedDCP.uncheckedTreeFromMinimizationExpr e - fromComponents ocTree.objFun ocTree.constr vars - -/-- Flatten an `EggTree` to a string to send to egg. -/ -partial def EggTree.toEggString : Tree String String → String - | Tree.node n children => - let childrenStr := (children.map EggTree.toEggString).data - "(" ++ n ++ " " ++ (" ".intercalate childrenStr) ++ ")" - | Tree.leaf n => n + fromComponents ocTree.objFun ocTree.constrs vars -/-- Size of the AST. -/ -partial def EggTree.size : EggTree → Nat - | Tree.node "var" _ => 1 - | Tree.node _ children => 1 + (children.map EggTree.size).foldl Nat.add 0 - | Tree.leaf _ => 1 +end EggOCTreeExtended -/-- -/ -def EggTree.ofOCTree (ocTree : OC (String × Tree String String)) : Tree String String := - let objFun := ocTree.objFun.2 - let constrs := ocTree.constr.map - fun (h, c) => Tree.node "constr" #[Tree.leaf h, c] - let objFunNode := Tree.node "objFun" #[objFun] - let constrNode := Tree.node "constrs" constrs - Tree.node "prob" #[objFunNode, constrNode] +namespace EggMinimization -/-- Convert `OC` tree to `EggMinimization`. -/ -def EggMinimization.ofOCTree (oc : OC (String × EggTree)) : EggMinimization := +/-- Convert `EggOCTree `EggMinimization` by flattening all trees. -/ +def ofEggOCTree (oc : EggOCTree) : EggMinimization := { objFun := EggTree.toEggString oc.objFun.2, - constrs := Array.data <| oc.constr.map fun (h, c) => (h, EggTree.toEggString c) } + constrs := Array.data <| oc.constrs.map fun (h, c) => (h, EggTree.toEggString c) } -end Egg.FromMinimization +end EggMinimization end CvxLean diff --git a/CvxLean/Tactic/PreDCP/Egg/Runner.lean b/CvxLean/Tactic/PreDCP/Egg/Runner.lean index 6bbc6cd9..4f286479 100644 --- a/CvxLean/Tactic/PreDCP/Egg/Runner.lean +++ b/CvxLean/Tactic/PreDCP/Egg/Runner.lean @@ -1,12 +1,20 @@ import Lean import CvxLean.Tactic.PreDCP.Egg.EggTypes +/-! +# Call the `egg` sub-process and get a list of rewrites + +This file defines the `runEggRequest` function, which is used to call the `egg` sub-process. Some +pre-processing is necessary to encode domains and minimization problems into the JSON. Also, some +post-processing is needed to parse the JSON response. + +Taken from / +-/ + namespace CvxLean open Lean --- Taken from https://github.com/opencompl/egg-tactic-code - def _root_.Lean.Json.getStr! (j : Json) : String := match j with | Json.str a => a @@ -19,7 +27,7 @@ def _root_.Lean.Json.getArr! (j : Json) : Array Json := variable {ε α} -def _root_.MetaM.ofExcept [ToString ε]: Except ε α -> MetaM α := +def _root_.MetaM.ofExcept [ToString ε]: Except ε α → MetaM α := fun e => match e with | Except.error msg => throwError (toString msg) @@ -33,7 +41,7 @@ instance : MonadExceptOf String MetaM := { def surroundQuotes (s : String) : String := "\"" ++ s ++ "\"" -/-- TODO +/-- JSONify an EggMinimization. NOTE: Tuples are lists of two elements. -/ def EggMinimization.toJson (e : EggMinimization) : String := "{" ++ @@ -54,11 +62,7 @@ def EggDomain.toJson (e : EggDomain) : String := surroundQuotes e.lo_open ++ "]" -structure EggRequest where - domains : List (String × EggDomain) -- Tuples are lists of two elements. - target : EggMinimization - -/-- TODO +/-- JSONify an `EggRequest`. NOTE: Tuples are lists of two elements. -/ def EggRequest.toJson (e : EggRequest) : String := "{" ++ @@ -72,10 +76,6 @@ def EggRequest.toJson (e : EggRequest) : String := surroundQuotes "target" ++ " : " ++ (e.target.toJson) ++ "}" -inductive EggRewriteDirection where - | Forward - | Backward - deriving Inhabited, DecidableEq def EggRewriteDirection.toString : EggRewriteDirection → String | Forward => "fwd" @@ -84,14 +84,6 @@ def EggRewriteDirection.toString : EggRewriteDirection → String instance : ToString EggRewriteDirection where toString := EggRewriteDirection.toString -structure EggRewrite where - rewriteName : String - direction : EggRewriteDirection - location : String - subexprFrom : String - subexprTo : String - expectedTerm : String - def EggRewrite.toString (e : EggRewrite) : String := "{" ++ surroundQuotes "rewrite_name" ++ " : " ++ surroundQuotes e.rewriteName ++ ", " ++ @@ -105,24 +97,27 @@ def EggRewrite.toString (e : EggRewrite) : String := instance : ToString EggRewrite where toString := EggRewrite.toString +/-- Call the `egg` sub-process by runnning the executable `egg-pre-dcp/utils/egg-pre-dcp`, generated +by running `lake build EggPreDCP`. The input is passed via standard input. -/ def runEggRequestRaw (requestJson : String) : MetaM String := do - let eggProcess ← IO.Process.spawn { - cmd := "egg-pre-dcp/utils/egg-pre-dcp", - stdout := IO.Process.Stdio.piped, - stdin := IO.Process.Stdio.piped, - stderr := IO.Process.Stdio.null - } + let eggProcess ← IO.Process.spawn { + cmd := "egg-pre-dcp/utils/egg-pre-dcp", + stdout := IO.Process.Stdio.piped, + stdin := IO.Process.Stdio.piped, + stderr := IO.Process.Stdio.null + } - let (stdin, eggProcess) ← eggProcess.takeStdin - stdin.putStr requestJson + let (stdin, eggProcess) ← eggProcess.takeStdin + stdin.putStr requestJson - let stdout ← IO.asTask eggProcess.stdout.readToEnd Task.Priority.dedicated - let stdout : String ← IO.ofExcept stdout.get - let exitCode ← eggProcess.wait - dbg_trace s!"Egg exit code: {exitCode}" + let stdout ← IO.asTask eggProcess.stdout.readToEnd Task.Priority.dedicated + let stdout : String ← IO.ofExcept stdout.get + let exitCode ← eggProcess.wait + dbg_trace s!"Egg exit code: {exitCode}" - return stdout + return stdout +/-- Read `egg`'s output and trun it into an array of `EggRewrite`s. -/ def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do dbg_trace s!"Egg response: {responseString}" let outJson : Json ← match Json.parse responseString with @@ -159,7 +154,7 @@ def parseEggResponse (responseString : String) : MetaM (Array EggRewrite) := do return res -/-- -/ +/-- Run request to `egg` and parse the output to get an array of rewrites, if successful. -/ def runEggRequest (request : EggRequest) : MetaM (Array EggRewrite) := dbg_trace s!"Running egg request: {request.toJson}" runEggRequestRaw request.toJson >>= parseEggResponse diff --git a/CvxLean/Tactic/PreDCP/Egg/Sexp.lean b/CvxLean/Tactic/PreDCP/Egg/Sexp.lean index 19bbc5b4..3cb18c29 100644 --- a/CvxLean/Tactic/PreDCP/Egg/Sexp.lean +++ b/CvxLean/Tactic/PreDCP/Egg/Sexp.lean @@ -1,8 +1,14 @@ import Lean -open Lean +/-! +# S-expressions + +This file defines S-expressions, and provides a parser for them. + +Taken from . +-/ --- Taken from https://github.com/opencompl/egg-tactic-code +open Lean inductive Sexp | atom : String → Sexp @@ -109,16 +115,17 @@ def SexpM.pop: SexpM SexpTok := do return x -- Remove elements from the stack of tokens `List SexpToken` till we find a `SexpToken.opening`. --- When we do, return (1) the position of the open paren, (2) the list of SexpTokens left on the stack, and (3) the list of Sexps +-- When we do, return (1) the position of the open paren, (2) the list of SexpTokens left on the +-- stack, and (3) the list of Sexps -- Until then, accumulate the `SexpToken.sexp`s into `sexps`. def stackPopTillOpen (stk : List SexpTok) (sexps : List Sexp := []) : - Option (String.Pos × (List SexpTok) × (List Sexp)) := + Option (String.Pos × (List SexpTok) × (List Sexp)) := match stk with | [] => .none | SexpTok.opening openPos :: rest => (.some (openPos, rest, sexps)) | SexpTok.sexp s :: rest => stackPopTillOpen rest (s :: sexps) -/-- Collapse the current stack till the last ( into a single `Sexp.list`. -/ +/-- Collapse the current stack till the last `(` into a single `Sexp.list`. -/ def SexpM.matchClosingParen: SexpM Unit := do let state ← get match stackPopTillOpen state.stack with @@ -163,15 +170,14 @@ partial def SexpM.parse : SexpM Unit := do SexpError.unmatchedOpenParen ({ s := state.it.s, i := openPos }) | none => return () -/-- Parse a list of (possibly empty) sexps. -/ +/-- Parse a list of (possibly empty) S-expressions. -/ def parseSexpList (s: String): Except SexpError (List Sexp) := let initState := SexpState.fromString s match EStateM.run SexpM.parse initState with | .ok () state => .ok state.sexps.reverse | .error e _ => .error e -/-- Parse a single s-expression, and error if found no sexp or multiple sexps. --/ +/-- Parse a single s-expression, and error if found no sexp or multiple S-expressions. -/ def parseSingleSexp (s: String) : Except SexpError Sexp := do match (← parseSexpList s) with | [x] => .ok x diff --git a/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean b/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean index f58dbf9a..443e18a9 100644 --- a/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean +++ b/CvxLean/Tactic/PreDCP/Egg/ToExpr.lean @@ -2,71 +2,91 @@ import CvxLean.Lib.Math.Data.Real import CvxLean.Meta.Minimization import CvxLean.Meta.Util.Error import CvxLean.Tactic.PreDCP.Egg.Sexp +import CvxLean.Tactic.PreDCP.Egg.EggTypes import CvxLean.Tactic.DCP.Tree -namespace CvxLean +/-! +# Strings as S-expressions to `Expr` + +This file is, essentially, the reverse of `CvxLean/Tactic/PreDCP/Egg/FromMinimization`. Once `egg` +has given us a list of steps, we need to turn them into expressions to re-construct a proof. We do +that here. This is used by the `pre_dcp` tactic. +-/ -section Egg.ToExpr +namespace CvxLean open Lean Meta -/-- -/ -partial def Sexpr.toEggTree : Sexp → MetaM (Tree String String) +/-- Parse an S-expression into an `EggTree`. -/ +partial def Sexpr.toEggTree : Sexp → MetaM EggTree | .atom a => return Tree.leaf a | .list l => do if l.length == 0 then - throwError "Sexp to Tree conversion error: unexpected empty list." + throwPreDCPError "`Sexp` to `Tree` conversion error, unexpected empty list." else match l.head! with | .atom op => - let children ← l.tail.mapM toEggTree - return CvxLean.Tree.node op (Array.mk children) - | .list _ => throwError "Sexp to Tree conversion error: unexpected list as operator." + let children ← l.tail.mapM toEggTree + return CvxLean.Tree.node op (Array.mk children) + | .list _ => + throwPreDCPError "`Sexp` to `Tree` conversion error, unexpected list as operator." -def EggString.toEggTree (s : String) : MetaM (Tree String String) := do +/-- Parse an string into an S-expression and, if successful, parse it into an `EggTree`. -/ +def EggString.toEggTree (s : String) : MetaM EggTree := do match parseSingleSexp s with | Except.ok sexpr => Sexpr.toEggTree sexpr | Except.error e => throwError s!"{e}" -/-- -/ -partial def EggTree.toExpr (vars params : List String) : Tree String String → MetaM Expr +/-- After turning an S-expression into an `EggTree` (AST), we traverse it to turn it into a Lean +`Expr`. This mostly involves translating the constructors (as strings) into their corresponding +Lean constructors. + +There are some special cases: +* Variables are translated into free variables and parameters are translated into constants. +* Constraints are composed using their tags, which are embedded into the tree. +* Numerical values are parsed into `OfScientific`s. +* In the expected sub-expression, there might be a `"?"`. This translates to a free varaible called + `subexpr` (no optimization variable may be named `subexpr`). It is `pre_dcp`'s responsibility to + bind this free variable appropriately (as well as the rest of free variables). -/ +partial def EggTree.toExpr (vars params : List String) : EggTree → MetaM Expr -- Pi. | Tree.leaf "pi" => do return mkConst ``Real.pi -- Special case: hole. Used to apply `congrAgr` in `pre_dcp`. | Tree.leaf "?" => do - if "subexpr" ∈ vars then - throwPreDCPError "\"subexpr\" is not a valid variable name." - return mkFVar (FVarId.mk (Name.mkSimple "subexpr")) + if "subexpr" ∈ vars then + throwPreDCPError "\"subexpr\" is not a valid variable name." + return mkFVar (FVarId.mk (Name.mkSimple "subexpr")) -- Numbers. | Tree.leaf s => - match Json.Parser.num s.mkIterator with - | Parsec.ParseResult.success _ res => do - -- NOTE: not ideal, but `norm_num` should get us to where we want. - let divisionRingToOfScientific := - mkApp2 (mkConst ``DivisionRing.toOfScientific [levelZero]) - (mkConst ``Real) - (mkConst ``Real.instDivisionRingReal) - let realOfScientific := - mkApp2 (mkConst ``OfScientific.ofScientific [levelZero]) - (mkConst ``Real) - divisionRingToOfScientific - let num := mkApp3 realOfScientific - (mkNatLit res.mantissa.natAbs) (Lean.toExpr true) (mkNatLit res.exponent) - let expr := - if res.mantissa < 0 then - mkApp3 (mkConst ``Neg.neg [levelZero]) (mkConst ``Real) (mkConst ``Real.instNegReal) num - else num - let simpResult ← - Mathlib.Meta.NormNum.deriveSimp (ctx := ← Simp.Context.mkDefault) (e := expr) - return simpResult.expr - | _ => throwPreDCPError "conversion error, unexpected num {s}." + match Json.Parser.num s.mkIterator with + | Parsec.ParseResult.success _ res => do + -- NOTE: not ideal, but `norm_num` should get us to where we want. + let divisionRingToOfScientific := + mkApp2 (mkConst ``DivisionRing.toOfScientific [levelZero]) + (mkConst ``Real) + (mkConst ``Real.instDivisionRingReal) + let realOfScientific := + mkApp2 (mkConst ``OfScientific.ofScientific [levelZero]) + (mkConst ``Real) + divisionRingToOfScientific + let num := mkApp3 realOfScientific + (mkNatLit res.mantissa.natAbs) (Lean.toExpr true) (mkNatLit res.exponent) + let expr := + if res.mantissa < 0 then + mkApp3 + (mkConst ``Neg.neg [levelZero]) (mkConst ``Real) (mkConst ``Real.instNegReal) num + else num + let simpResult ← + Mathlib.Meta.NormNum.deriveSimp (ctx := ← Simp.Context.mkDefault) (e := expr) + return simpResult.expr + | _ => throwPreDCPError "conversion error, unexpected num {s}." -- Variables. | Tree.node "var" #[Tree.leaf s] => - if s ∈ vars then - return mkFVar (FVarId.mk (Name.mkSimple s)) - else - throwPreDCPError "conversion error, unexpected var {s}." + if s ∈ vars then + return mkFVar (FVarId.mk (Name.mkSimple s)) + else + throwPreDCPError "conversion error, unexpected var {s}." -- Parameters. | Tree.node "param" #[Tree.leaf s] => if s ∈ params then @@ -75,46 +95,46 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → return mkConst (Name.mkSimple s) -- Equality. | Tree.node "eq" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkAppN (mkConst ``Eq [levelOne]) #[(mkConst `Real), t1, t2] + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkAppN (mkConst ``Eq [levelOne]) #[(mkConst `Real), t1, t2] -- Less than or equal to. | Tree.node "le" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkAppN - (mkConst ``LE.le [levelZero]) - #[(mkConst `Real), (mkConst `Real.instLEReal), t1, t2] + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkAppN + (mkConst ``LE.le [levelZero]) + #[(mkConst `Real), (mkConst `Real.instLEReal), t1, t2] -- Negation. | Tree.node "neg" #[t] => do - let t ← toExpr vars params t - return mkAppN - (mkConst ``Neg.neg [levelZero]) - #[(mkConst ``Real), (mkConst ``Real.instNegReal), t] + let t ← toExpr vars params t + return mkAppN + (mkConst ``Neg.neg [levelZero]) + #[(mkConst ``Real), (mkConst ``Real.instNegReal), t] -- Inverse. | Tree.node "inv" #[t] => do - let t ← toExpr vars params t - return mkAppN - (mkConst ``Inv.inv [levelZero]) - #[(mkConst ``Real), (mkConst ``Real.instInvReal), t] + let t ← toExpr vars params t + return mkAppN + (mkConst ``Inv.inv [levelZero]) + #[(mkConst ``Real), (mkConst ``Real.instInvReal), t] -- Absolute value. | Tree.node "abs" #[t] => do - let t ← toExpr vars params t - return mkAppN - (mkConst ``abs [levelZero]) - #[(mkConst ``Real), (mkConst ``Real.lattice), (mkConst ``Real.instAddGroupReal), t] + let t ← toExpr vars params t + return mkAppN + (mkConst ``abs [levelZero]) + #[(mkConst ``Real), (mkConst ``Real.lattice), (mkConst ``Real.instAddGroupReal), t] -- Square root. | Tree.node "sqrt" #[t] => do - let t ← toExpr vars params t - return mkAppN (mkConst ``Real.sqrt) #[t] + let t ← toExpr vars params t + return mkAppN (mkConst ``Real.sqrt) #[t] -- Log. | Tree.node "log" #[t] => do - let t ← toExpr vars params t - return mkAppN (mkConst ``Real.log) #[t] + let t ← toExpr vars params t + return mkAppN (mkConst ``Real.log) #[t] -- Exp. | Tree.node "exp" #[t] => do - let t ← toExpr vars params t - return mkAppN (mkConst ``Real.exp) #[t] + let t ← toExpr vars params t + return mkAppN (mkConst ``Real.exp) #[t] -- XExp. | Tree.node "xexp" #[t] => EggTree.toExpr vars params (Tree.node "mul" #[t, Tree.node "exp" #[t]]) @@ -123,54 +143,54 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → EggTree.toExpr vars params (Tree.node "neg" #[Tree.node "mul" #[t, Tree.node "log" #[t]]]) -- Min. | Tree.node "min" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkAppN - (mkConst ``Min.min [levelZero]) - #[(mkConst ``Real), (mkConst ``Real.instMinReal), t1, t2] + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkAppN + (mkConst ``Min.min [levelZero]) + #[(mkConst ``Real), (mkConst ``Real.instMinReal), t1, t2] -- Max. | Tree.node "max" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkAppN - (mkConst ``Max.max [levelZero]) - #[(mkConst ``Real), (mkConst ``Real.instMaxReal), t1, t2] + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkAppN + (mkConst ``Max.max [levelZero]) + #[(mkConst ``Real), (mkConst ``Real.instMaxReal), t1, t2] -- Addition. | Tree.node "add" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkRealHBinAppExpr ``HAdd.hAdd ``instHAdd 1 ``Real.instAddReal t1 t2 + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkRealHBinAppExpr ``HAdd.hAdd ``instHAdd 1 ``Real.instAddReal t1 t2 -- Subtraction. | Tree.node "sub" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkRealHBinAppExpr ``HSub.hSub ``instHSub 1 ``Real.instSubReal t1 t2 + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkRealHBinAppExpr ``HSub.hSub ``instHSub 1 ``Real.instSubReal t1 t2 -- Multiplication. | Tree.node "mul" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkRealHBinAppExpr ``HMul.hMul ``instHMul 1 ``Real.instMulReal t1 t2 + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkRealHBinAppExpr ``HMul.hMul ``instHMul 1 ``Real.instMulReal t1 t2 -- Division. | Tree.node "div" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkRealHBinAppExpr ``HDiv.hDiv ``instHDiv 1 ``Real.instDivReal t1 t2 + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkRealHBinAppExpr ``HDiv.hDiv ``instHDiv 1 ``Real.instDivReal t1 t2 -- Pow. | Tree.node "pow" #[t1, t2] => do - let t1 ← toExpr vars params t1 - let t2 ← toExpr vars params t2 - return mkRealHBinAppExpr ``HPow.hPow ``instHPow 2 ``Real.instPowReal t1 t2 + let t1 ← toExpr vars params t1 + let t2 ← toExpr vars params t2 + return mkRealHBinAppExpr ``HPow.hPow ``instHPow 2 ``Real.instPowReal t1 t2 -- Quad over Lin. | Tree.node "qol" #[t1, t2] => - EggTree.toExpr vars params (Tree.node "div" #[Tree.node "pow" #[t1, Tree.leaf "2"], t2]) + EggTree.toExpr vars params (Tree.node "div" #[Tree.node "pow" #[t1, Tree.leaf "2"], t2]) -- Geo mean. | Tree.node "geo" #[t1, t2] => - EggTree.toExpr vars params (Tree.node "sqrt" #[Tree.node "mul" #[t1, t2]]) + EggTree.toExpr vars params (Tree.node "sqrt" #[Tree.node "mul" #[t1, t2]]) -- Log sum exp. | Tree.node "lse" #[t1, t2] => - EggTree.toExpr vars params (Tree.node "log" #[Tree.node "add" #[ - Tree.node "exp" #[t1], - Tree.node "exp" #[t2]]]) + EggTree.toExpr vars params (Tree.node "log" #[Tree.node "add" #[ + Tree.node "exp" #[t1], + Tree.node "exp" #[t2]]]) -- Norm2. | Tree.node "norm2" #[t1, t2] => EggTree.toExpr vars params (Tree.node "sqrt" #[Tree.node "add" #[ @@ -178,11 +198,11 @@ partial def EggTree.toExpr (vars params : List String) : Tree String String → Tree.node "pow" #[t2, Tree.leaf "2"]]]) -- Constr. | Tree.node "constr" #[Tree.leaf s, t] => do - let t ← toExpr vars params t - return Meta.mkLabel (Name.mkSimple s) t + let t ← toExpr vars params t + return Meta.mkLabel (Name.mkSimple s) t -- Error. | Tree.node op children => - throwPreDCPError "conversion error, unexpected op {op} with {children.size} children." + throwPreDCPError "conversion error, unexpected op {op} with {children.size} children." where mkRealHBinAppExpr (opName instHName : Name) (nTyArgs : Nat) (instName : Name) (e1 e2 : Expr) : Expr := @@ -193,10 +213,9 @@ where (mkConst opName [levelZero, levelZero, levelZero]) #[R, R, R, inst, e1, e2] -/-- -/ +/-- First turn a string (as a S-expression) into an `EggTree`, then call `EggTree.toExpr` to get +an expression. -/ def EggString.toExpr (vars params : List Name) (s : String) : MetaM Expr := EggString.toEggTree s >>= EggTree.toExpr (vars.map toString) (params.map toString) -end Egg.ToExpr - end CvxLean diff --git a/CvxLean/Tactic/PreDCP/UncheckedDCP.lean b/CvxLean/Tactic/PreDCP/Egg/UncheckedDCP.lean similarity index 81% rename from CvxLean/Tactic/PreDCP/UncheckedDCP.lean rename to CvxLean/Tactic/PreDCP/Egg/UncheckedDCP.lean index 65705102..80174354 100644 --- a/CvxLean/Tactic/PreDCP/UncheckedDCP.lean +++ b/CvxLean/Tactic/PreDCP/Egg/UncheckedDCP.lean @@ -2,6 +2,14 @@ import CvxLean.Syntax.OptimizationParam import CvxLean.Meta.Util.Expr import CvxLean.Meta.Util.Error import CvxLean.Tactic.DCP.DCP +import CvxLean.Tactic.PreDCP.Egg.EggTypes + +/-! +This is the first step when applying the `pre_dcp` tactic that allows us to get the AST of the +minimization problem re-using the machinery from DCP. This file works together with +`CvxLean/Tactic/PreDCP/Egg/FromMinimiation.lean`, which takes the AST returned by +`uncheckedTreeFromMinimizationExpr` and turns it into an S-expression. +-/ namespace CvxLean @@ -9,8 +17,11 @@ open Lean Meta namespace UncheckedDCP -partial def mkUncheckedTree (originalVarsDecls paramsDecls : Array LocalDecl) (oc : OC (Option (Name × Expr))) : - MetaM (OC (Option (String × Tree String String))) := do +/-- We re-use the machinery from DCP to easily turn expressions into ASTs represented by a tree +of strings. The input is simply the underlying expressions of the optimization problem (as `Expr`). +As usual, optimization variables and parameters need to be added to the local context. -/ +partial def mkUncheckedTree (originalVarsDecls paramsDecls : Array LocalDecl) + (oc : OC (Option (Name × Expr))) : MetaM (OC (Option (String × Tree String String))) := do withExistingLocalDecls (originalVarsDecls ++ paramsDecls).toList do let varsIds := originalVarsDecls.map fun decl => decl.fvarId let paramsIds := paramsDecls.map fun decl => decl.fvarId @@ -20,6 +31,7 @@ partial def mkUncheckedTree (originalVarsDecls paramsDecls : Array LocalDecl) (o return some (n.toString, uncheckedAtomTree) | none => return none) where + /-- Like DCP's `findAtoms` but ignoring curvature checks. The result is a tree of strings. -/ findUncheckedAtoms (e : Expr) (vars params : Array FVarId) : MetaM (Tree String String) := do -- Numerical constants need to be handled separately, as they are ignored by `DCP`. let optParamsIds := (← getAllOptimizationParams).map FVarId.mk @@ -55,7 +67,7 @@ where if e.isFVar && vars.contains e.fvarId! then let n := (originalVarsDecls.find? (fun decl => decl.fvarId == e.fvarId!)).get!.userName return Tree.node "var" #[Tree.leaf (toString n)] - if e.isFVar ∧ params.contains e.fvarId! then + if e.isFVar && params.contains e.fvarId! then let n := (paramsDecls.find? (fun decl => decl.fvarId == e.fvarId!)).get!.userName return Tree.node "param" #[Tree.leaf (toString n)] if e.isConst && (← isOptimizationParam e.constName) then @@ -78,8 +90,9 @@ where return res - processUncheckedAtom (vars params : Array FVarId) (atom : GraphAtomData) - (args : Array Expr) : MetaM (Tree String String) := do + /-- Apply `findUncheckedAtoms` to children and build a tree. -/ + processUncheckedAtom (vars params : Array FVarId) (atom : GraphAtomData) (args : Array Expr) : + MetaM EggTree := do let mut childTrees := #[] for i in [:args.size] do let arg := args[i]! @@ -88,9 +101,9 @@ where return Tree.node (toString atom.id) childTrees -/-- -/ -def uncheckedTreeFromMinimizationExpr (goalExprs : MinimizationExpr) : - MetaM (OC (String × Tree String String)) := do +/-- Wrapper of `mkUncheckedTree` that takes a minimization expression, decomposes it by its +components and returns the corresponding AST as a tree of strings. -/ +def uncheckedTreeFromMinimizationExpr (goalExprs : MinimizationExpr) : MetaM EggOCTree := do let (objFun, constraints, originalVarsDecls) ← withLambdaBody goalExprs.constraints fun p constraints => do let pr := (← Meta.mkProjections goalExprs.domain p).toArray diff --git a/CvxLean/Tactic/PreDCP/PreDCP.lean b/CvxLean/Tactic/PreDCP/PreDCP.lean index 02ae90a1..713abdcf 100644 --- a/CvxLean/Tactic/PreDCP/PreDCP.lean +++ b/CvxLean/Tactic/PreDCP/PreDCP.lean @@ -181,36 +181,37 @@ def preDCPBuilder : EquivalenceBuilder Unit := fun eqvExpr g => g.withContext do let paramsDomainsOC : OC (Option (Name × Expr)) := OC.mk none (paramsDomainsExpr.map some) let potentialParamDomains ← UncheckedDCP.mkUncheckedTree #[] paramsDecls.toArray paramsDomainsOC let mut paramDomains := #[] - for c in potentialParamDomains.constr do + for c in potentialParamDomains.constrs do if let some (h, t) := c then - if let some (_, n, d) := ExtendedEggTree.processDomainExprTree h t paramsStr then + if let some ⟨_, n, d⟩ := EggOCTreeExtended.processDomainExprTree h t paramsStr then paramDomains := paramDomains.push (n, d) -- Get goal as tree and create tags map. - let (gStr, domainConstrs) ← ExtendedEggTree.fromMinimization lhs varsStr + let (gStr, domainConstrs) ← EggOCTreeExtended.fromMinimization lhs varsStr let probSize := (gStr.map fun (_, t) => t.size).fold 0 Nat.add let mut tagsMap := HashMap.empty tagsMap := tagsMap.insert "objFun" 0 let mut idx := 1 - for (h, _) in gStr.constr do + for (h, _) in gStr.constrs do tagsMap := tagsMap.insert h idx idx := idx + 1 -- Handle domain constraints. - let varDomainConstrs := domainConstrs.map (fun (_, v, d) => (v, d)) - let constrsToIgnore := domainConstrs.map (fun (h, _, _) => h) + let varDomainConstrs := domainConstrs.map (fun ⟨_, v, d⟩ => (v, d)) + let constrsToIgnore := domainConstrs.map (fun ⟨h, _, _⟩ => h) -- Remove domain constraints before sending it to egg. - let gStr := { gStr with constr := gStr.constr.filter (fun (h, _) => !constrsToIgnore.contains h) } + let gStr := { gStr with + constrs := gStr.constrs.filter (fun (h, _) => !constrsToIgnore.contains h) } - -- Prepare egg request. - let eggMinimization := EggMinimization.ofOCTree gStr + -- Prepare `egg` request. + let eggMinimization := EggMinimization.ofEggOCTree gStr let eggRequest : EggRequest := { domains := (varDomainConstrs ++ paramDomains).data, target := eggMinimization } try - -- Call egg (time it for evaluation). + -- Call `egg` (time it for evaluation). let bef ← BaseIO.toIO IO.monoMsNow let steps ← runEggRequest eggRequest let aft ← BaseIO.toIO IO.monoMsNow diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean b/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean index df3be186..8a5c2a61 100644 --- a/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean +++ b/CvxLean/Tactic/PreDCP/RuleToTacticCmd.lean @@ -1,31 +1,41 @@ import Lean import CvxLean.Tactic.PreDCP.RuleToTacticExt +/-! +# Egg rewrite rule to tactic (command) + +This files defines the commands to associate rule names to Lean tactics. There are two commands: +* `register_rule_to_tactic`, for if-and-only-if or real equality rewrites. +* `register_objFun_rule_to_tactic`, for problem-level rewrites (currently these are pre-compositions + applied to the objective function). They do the same but set the `mapObjFun` flag to `true` in the + environment extension so that `pre_dcp` knows that it should be applied directly. +-/ + namespace CvxLean open Lean Parser syntax (name := registerRuleToTactic) - "register_rule_to_tactic " str " ; " str " => " str " := " tactic ";" : command + "register_rule_to_tactic " str ";" str " => " str " := " tactic ";" : command syntax (name := registerRuleToTacticBidirectional) - "register_rule_to_tactic " str " ; " str " <=> " str " := " tactic ";" : command + "register_rule_to_tactic " str ";" str " <=> " str " := " tactic ";" : command macro_rules - | `(register_rule_to_tactic $rwName ; $rwTarget <=> $rwGoal := $tac;) => + | `(register_rule_to_tactic $rwName; $rwTarget <=> $rwGoal := $tac;) => if let some rwNameStr := Syntax.isStrLit? rwName then let rwNameRev := Syntax.mkStrLit (rwNameStr ++ "-rev") - `(register_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac; - register_rule_to_tactic $rwNameRev ; $rwGoal => $rwTarget := $tac;) + `(register_rule_to_tactic $rwName; $rwTarget => $rwGoal := $tac; + register_rule_to_tactic $rwNameRev; $rwGoal => $rwTarget := $tac;) else `(throwError "`register_rule_to_tactic` error: expected string.") syntax (name := registerObjFunRuleToTactic) - "register_objFun_rule_to_tactic " str " ; " str " => " str " := " tactic ";" : command + "register_objFun_rule_to_tactic " str ";" str " => " str " := " tactic ";" : command open Lean.Elab Lean.Elab.Command @[command_elab registerRuleToTactic] def elabRuleToTactic : CommandElab -| `(register_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac;) => do +| `(register_rule_to_tactic $rwName; $rwTarget => $rwGoal := $tac;) => do let rwNameStr := rwName.getString -- NOTE: We ignore this for now. let _rwTargetStr := rwTarget.getString @@ -34,7 +44,7 @@ open Lean.Elab Lean.Elab.Command | _ => throwUnsupportedSyntax @[command_elab registerObjFunRuleToTactic] def elabRegisterObjFunRuleToTactic : CommandElab -| `(register_objFun_rule_to_tactic $rwName ; $rwTarget => $rwGoal := $tac;) => do +| `(register_objFun_rule_to_tactic $rwName; $rwTarget => $rwGoal := $tac;) => do let rwNameStr := rwName.getString -- NOTE: We ignore this for now. let _rwTargetStr := rwTarget.getString diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean b/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean index b9f8ef8b..b11d0258 100644 --- a/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean +++ b/CvxLean/Tactic/PreDCP/RuleToTacticExt.lean @@ -1,7 +1,7 @@ import Lean /-! -# Egg rewrite rule to tactic (extension) +# Egg rewrite rule to tactic (environment extension) This files defines the persistent environment extension used to store translation of `egg` rewrite rule names to Lean tactics. diff --git a/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean b/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean index 74c8af2b..39d1ea14 100644 --- a/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean +++ b/CvxLean/Tactic/PreDCP/RuleToTacticLibrary.lean @@ -3,14 +3,21 @@ import CvxLean.Tactic.PreDCP.RuleToTacticCmd import CvxLean.Tactic.Arith.Arith /-! +# Egg rewrite rule to tactic (library) +This files associates a tactic to each rewrite name. Note that we do not use `rw` or `simp`, because +the `pre_dcp` tactic will set-up the goal so that it corresponds exactly to the lemma. + +We split the rewrites into problem rewrites, if-and-only-if rewrites, and real equality rewrites. -/ namespace CvxLean set_option linter.unreachableTactic false -/- Objective function rules. -/ +section ProblemRewrites + +/-! Objective function rules (problem-level rewrites). -/ register_objFun_rule_to_tactic "map_objFun_log"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (log ?a)) ?cs)" := @@ -20,22 +27,26 @@ register_objFun_rule_to_tactic "map_objFun_sq"; "(prob (objFun ?a) ?cs)" => "(prob (objFun (pow ?a 2)) ?cs)" := apply Minimization.Equivalence.map_objFun_sq (by positivity!); +end ProblemRewrites + + +section IfAndOnlyIfRewrites -/- Equality rules. -/ +/-! Equality if-and-only-if rules. -/ -register_rule_to_tactic "log_eq_log" ; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" := +register_rule_to_tactic "log_eq_log"; "(eq ?a ?b)" => "(eq (log ?a) (log ?b))" := apply Real.log_eq_log (by positivity!) (by positivity!); -register_rule_to_tactic "log_eq_log-rev" ; "(eq (log ?a) (log ?b))" => "(eq ?a ?b)" := +register_rule_to_tactic "log_eq_log-rev"; "(eq (log ?a) (log ?b))" => "(eq ?a ?b)" := apply (Real.log_eq_log (by positivity!) (by positivity!)).symm; -/- Less than or equal rules. -/ +/-! Less-than-or-equal if-and-only-if rules. -/ -register_rule_to_tactic "le_sub_iff_add_le" ; "(le ?a (sub ?b ?c))" => "(le (add ?a ?c) ?b)" := +register_rule_to_tactic "le_sub_iff_add_le"; "(le ?a (sub ?b ?c))" => "(le (add ?a ?c) ?b)" := apply le_sub_iff_add_le; -register_rule_to_tactic "le_sub_iff_add_le-rev" ; "(le (add ?a ?c) ?b)" => "(le ?a (sub ?b ?c))" := +register_rule_to_tactic "le_sub_iff_add_le-rev"; "(le (add ?a ?c) ?b)" => "(le ?a (sub ?b ?c))" := apply le_sub_iff_add_le.symm; register_rule_to_tactic "sub_le_iff_le_add"; "(le (sub ?a ?c) ?b)" => "(le ?a (add ?b ?c))" := @@ -44,19 +55,19 @@ register_rule_to_tactic "sub_le_iff_le_add"; "(le (sub ?a ?c) ?b)" => "(le ?a (a register_rule_to_tactic "sub_le_iff_le_add-rev"; "(le ?a (add ?b ?c))" => "(le (sub ?a ?c) ?b)" := apply sub_le_iff_le_add.symm; -register_rule_to_tactic "div_le_iff" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := +register_rule_to_tactic "div_le_iff"; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := apply div_le_iff (by positivity!); -register_rule_to_tactic "div_le_iff-rev" ; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := +register_rule_to_tactic "div_le_iff-rev"; "(le (div ?a ?b) ?c)" => "(le ?a (mul ?b ?c))" := apply (div_le_iff (by positivity!)).symm; -register_rule_to_tactic "div_le_one-rev" ; "(le ?a ?b)" => "(le (div ?a ?b) 1)" := +register_rule_to_tactic "div_le_one-rev"; "(le ?a ?b)" => "(le (div ?a ?b) 1)" := apply (div_le_one (by positivity!)).symm; -register_rule_to_tactic "log_le_log" ; "(le (log ?a) (log ?b))" => "(le ?a ?b)" := +register_rule_to_tactic "log_le_log"; "(le (log ?a) (log ?b))" => "(le ?a ?b)" := apply Real.log_le_log_iff (by positivity!) (by positivity!); -register_rule_to_tactic "log_le_log-rev" ; "(le ?a ?b)" => "(le (log ?a) (log ?b))" := +register_rule_to_tactic "log_le_log-rev"; "(le ?a ?b)" => "(le (log ?a) (log ?b))" := apply (Real.log_le_log_iff (by positivity!) (by positivity!)).symm; register_rule_to_tactic "pow_two_le_pow_two"; "(le (pow ?a 2) (pow ?b 2))" => "(le ?a ?b)":= @@ -65,38 +76,42 @@ register_rule_to_tactic "pow_two_le_pow_two"; "(le (pow ?a 2) (pow ?b 2))" => "( register_rule_to_tactic "pow_two_le_pow_two-rev"; "(le ?a ?b)" => "(le (pow ?a 2) (pow ?b 2))" := apply (Real.pow_two_le_pow_two (by positivity!) (by positivity!)).symm; +end IfAndOnlyIfRewrites -/- Field rules. -/ -register_rule_to_tactic "neg_neg" ; "(neg (neg ?a))" => "?a" := +section RealEqualityRewrites + +/-! Field rules. -/ + +register_rule_to_tactic "neg_neg"; "(neg (neg ?a))" => "?a" := apply neg_neg (G := ℝ); -register_rule_to_tactic "add_zero" ; "(add ?a 0)" => "?a" := +register_rule_to_tactic "add_zero"; "(add ?a 0)" => "?a" := apply add_zero; -register_rule_to_tactic "add_comm" ; "(add ?a ?b)" => "(add ?b ?a)" := +register_rule_to_tactic "add_comm"; "(add ?a ?b)" => "(add ?b ?a)" := apply add_comm; -register_rule_to_tactic "add_assoc" ; "(add (add ?a ?b) ?c)" => "(add ?a (add ?b ?c))" := +register_rule_to_tactic "add_assoc"; "(add (add ?a ?b) ?c)" => "(add ?a (add ?b ?c))" := apply add_assoc; -register_rule_to_tactic "sub_self" ; "(sub ?a ?a)" => "0" := +register_rule_to_tactic "sub_self"; "(sub ?a ?a)" => "0" := apply sub_self; -register_rule_to_tactic "one_mul" ; "(mul 1 ?a)" => "?a" := +register_rule_to_tactic "one_mul"; "(mul 1 ?a)" => "?a" := apply one_mul; -- Exception, we cannot find the pattern otherwise. -register_rule_to_tactic "one_mul-rev" ; "?a" => "(mul 1 ?a)" := +register_rule_to_tactic "one_mul-rev"; "?a" => "(mul 1 ?a)" := norm_num; -register_rule_to_tactic "mul_zero" ; "(mul ?a 0)" => "0" := +register_rule_to_tactic "mul_zero"; "(mul ?a 0)" => "0" := apply mul_zero (M₀ := ℝ); -register_rule_to_tactic "mul_comm" ; "(mul ?a ?b)" => "(mul ?b ?a)" := +register_rule_to_tactic "mul_comm"; "(mul ?a ?b)" => "(mul ?b ?a)" := apply mul_comm; -register_rule_to_tactic "mul_assoc" ; "(mul (mul ?a ?b) ?c)" => "(mul ?a (mul ?b ?c))" := +register_rule_to_tactic "mul_assoc"; "(mul (mul ?a ?b) ?c)" => "(mul ?a (mul ?b ?c))" := apply mul_assoc; register_rule_to_tactic "sub_eq_add_neg"; "(sub ?a ?b)" => "(add ?a (neg ?b))" := @@ -105,67 +120,67 @@ register_rule_to_tactic "sub_eq_add_neg"; "(sub ?a ?b)" => "(add ?a (neg ?b))" : register_rule_to_tactic "sub_eq_add_neg-rev"; "(add ?a (neg ?b))" => "(sub ?a ?b)" := apply (sub_eq_add_neg (G := ℝ) _ _).symm; -register_rule_to_tactic "add_sub" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := +register_rule_to_tactic "add_sub"; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := apply add_sub; -register_rule_to_tactic "add_sub-rev" ; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := +register_rule_to_tactic "add_sub-rev"; "(add ?a (sub ?b ?c))" => "(sub (add ?a ?b) ?c)" := apply (add_sub _ _ _).symm; -register_rule_to_tactic "add_mul" ; "(mul (add ?a ?b) ?c)" => "(add (mul ?a ?c) (mul ?b ?c))" := +register_rule_to_tactic "add_mul"; "(mul (add ?a ?b) ?c)" => "(add (mul ?a ?c) (mul ?b ?c))" := apply add_mul; -register_rule_to_tactic "add_mul-rev" ; "(add (mul ?a ?c) (mul ?b ?c))" => "(mul (add ?a ?b) ?c)" := +register_rule_to_tactic "add_mul-rev"; "(add (mul ?a ?c) (mul ?b ?c))" => "(mul (add ?a ?b) ?c)" := apply (add_mul _ _ _).symm; -register_rule_to_tactic "mul_add" ; "(mul ?a (add ?b ?c))" => "(add (mul ?a ?b) (mul ?a ?c))" := +register_rule_to_tactic "mul_add"; "(mul ?a (add ?b ?c))" => "(add (mul ?a ?b) (mul ?a ?c))" := apply mul_add; -register_rule_to_tactic "mul_add-rev" ; "(add (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (add ?b ?c))" := +register_rule_to_tactic "mul_add-rev"; "(add (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (add ?b ?c))" := apply (mul_add _ _ _).symm; -register_rule_to_tactic "mul_sub" ; "(mul ?a (sub ?b ?c))" => "(sub (mul ?a ?b) (mul ?a ?c))" := +register_rule_to_tactic "mul_sub"; "(mul ?a (sub ?b ?c))" => "(sub (mul ?a ?b) (mul ?a ?c))" := apply mul_sub; -register_rule_to_tactic "mul_sub-rev" ; "(sub (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (sub ?b ?c))" := +register_rule_to_tactic "mul_sub-rev"; "(sub (mul ?a ?b) (mul ?a ?c))" => "(mul ?a (sub ?b ?c))" := apply (mul_sub _ _ _).symm; -register_rule_to_tactic "div_one" ; "(div ?a 1)" => "?a" := +register_rule_to_tactic "div_one"; "(div ?a 1)" => "?a" := apply div_one; -register_rule_to_tactic "add_div" ; "(div (add ?a ?b) ?c)" => "(add (div ?a ?c) (div ?b ?c))" := +register_rule_to_tactic "add_div"; "(div (add ?a ?b) ?c)" => "(add (div ?a ?c) (div ?b ?c))" := apply add_div; -register_rule_to_tactic "add_div-rev" ; "(add (div ?a ?c) (div ?b ?c))" => "(div (add ?a ?b) ?c)" := +register_rule_to_tactic "add_div-rev"; "(add (div ?a ?c) (div ?b ?c))" => "(div (add ?a ?b) ?c)" := apply (add_div _ _ _).symm; -register_rule_to_tactic "mul_div" ; "(mul ?a (div ?b ?c))" => "(div (mul ?a ?b) ?c)" := +register_rule_to_tactic "mul_div"; "(mul ?a (div ?b ?c))" => "(div (mul ?a ?b) ?c)" := apply mul_div (G := ℝ); -register_rule_to_tactic "mul_div-rev" ; "(div (mul ?a ?b) ?c)" => "(mul ?a (div ?b ?c))" := +register_rule_to_tactic "mul_div-rev"; "(div (mul ?a ?b) ?c)" => "(mul ?a (div ?b ?c))" := apply (mul_div (G := ℝ) _ _ _).symm; -register_rule_to_tactic "div_mul" ; "(mul (div ?a ?b) ?c)" => "(div ?a (div ?b ?c))" := +register_rule_to_tactic "div_mul"; "(mul (div ?a ?b) ?c)" => "(div ?a (div ?b ?c))" := apply div_mul; -register_rule_to_tactic "div_mul-rev" ; "(div ?a (div ?b ?c))" => "(mul (div ?a ?b) ?c)" := +register_rule_to_tactic "div_mul-rev"; "(div ?a (div ?b ?c))" => "(mul (div ?a ?b) ?c)" := apply (div_mul _ _ _).symm; -register_rule_to_tactic "mul_div_mul_left" ; "(div (mul ?c ?a) (mul ?c ?b))" => "(div ?a ?b)" := +register_rule_to_tactic "mul_div_mul_left"; "(div (mul ?c ?a) (mul ?c ?b))" => "(div ?a ?b)" := apply mul_div_mul_left (G₀ := ℝ) _ _ (by positivity!); -register_rule_to_tactic "div_div" ; "(div (div ?a ?b) ?c)" => "(div ?a (mul ?b ?c))" := +register_rule_to_tactic "div_div"; "(div (div ?a ?b) ?c)" => "(div ?a (mul ?b ?c))" := apply div_div; -register_rule_to_tactic "div_div-rev" ; "(div ?a (mul ?b ?c))" => "(div (div ?a ?b) ?c)" := +register_rule_to_tactic "div_div-rev"; "(div ?a (mul ?b ?c))" => "(div (div ?a ?b) ?c)" := apply (div_div _ _ _).symm; -register_rule_to_tactic "div_self" ; "(div ?a ?a)" => "1" := +register_rule_to_tactic "div_self"; "(div ?a ?a)" => "1" := apply div_self (G₀ := ℝ) (by positivity!); -register_rule_to_tactic "inv_eq_one_div" ; "(inv ?a)" => "(div 1 ?a)" := +register_rule_to_tactic "inv_eq_one_div"; "(inv ?a)" => "(div 1 ?a)" := apply inv_eq_one_div (G := ℝ); -register_rule_to_tactic "inv_inv" ; "(inv (inv ?a))" => "?a" := +register_rule_to_tactic "inv_inv"; "(inv (inv ?a))" => "?a" := apply inv_inv; @@ -208,26 +223,26 @@ register_rule_to_tactic "pow_sub-rev"; "(div (pow ?a ?b) (pow ?a ?c))" => "(pow apply (Real.rpow_sub (by positivity!) _ _).symm; register_rule_to_tactic - "div_pow_eq_mul_pow_neg" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := + "div_pow_eq_mul_pow_neg"; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := apply Real.div_pow_eq_mul_pow_neg (by positivity!); register_rule_to_tactic - "div_pow_eq_mul_pow_neg-rev" ; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := + "div_pow_eq_mul_pow_neg-rev"; "(div ?a (pow ?b ?c))" => "(mul ?a (pow ?b (neg ?c)))" := apply (Real.div_pow_eq_mul_pow_neg (by positivity!)).symm; register_rule_to_tactic "one_div_eq_pow_neg_one"; "(div 1 ?a)" <=> "(pow ?a (neg 1))" := apply Real.one_div_eq_pow_neg_one (by positivity!); -register_rule_to_tactic "sqrt_eq_rpow" ; "(sqrt ?a)" => "(pow ?a 0.5)" := +register_rule_to_tactic "sqrt_eq_rpow"; "(sqrt ?a)" => "(pow ?a 0.5)" := apply Real.sqrt_eq_rpow; -register_rule_to_tactic "sqrt_eq_rpow-rev" ; "(pow ?a 0.5)" => "(sqrt ?a)" := +register_rule_to_tactic "sqrt_eq_rpow-rev"; "(pow ?a 0.5)" => "(sqrt ?a)" := apply (Real.sqrt_eq_rpow _).symm; -register_rule_to_tactic "sqrt_mul" ; "(sqrt (mul ?a ?b))" => "(mul (sqrt ?a) (sqrt ?b))" := +register_rule_to_tactic "sqrt_mul"; "(sqrt (mul ?a ?b))" => "(mul (sqrt ?a) (sqrt ?b))" := apply Real.sqrt_mul (by positivity!); -register_rule_to_tactic "sqrt_mul-rev" ; "(mul (sqrt ?a) (sqrt ?b))" => "(sqrt (mul ?a ?b))" := +register_rule_to_tactic "sqrt_mul-rev"; "(mul (sqrt ?a) (sqrt ?b))" => "(sqrt (mul ?a ?b))" := apply (Real.sqrt_mul (by positivity!) _).symm; register_rule_to_tactic "pow_two"; "(pow ?a 2)" => "(mul ?a ?a)" := @@ -255,68 +270,68 @@ register_rule_to_tactic "inv_eq_pow_neg_one"; "(inv ?a)" => "(pow ?a (neg 1))" : /- Exponential and logarithm rules. -/ -register_rule_to_tactic "exp_add" ; "(exp (add ?a ?b))" => "(mul (exp ?a) (exp ?b))" := +register_rule_to_tactic "exp_add"; "(exp (add ?a ?b))" => "(mul (exp ?a) (exp ?b))" := apply Real.exp_add; -register_rule_to_tactic "exp_add-rev" ; "(mul (exp ?a) (exp ?b))" => "(exp (add ?a ?b))" := +register_rule_to_tactic "exp_add-rev"; "(mul (exp ?a) (exp ?b))" => "(exp (add ?a ?b))" := apply (Real.exp_add _ _).symm; -register_rule_to_tactic "exp_sub" ; "(exp (sub ?a ?b))" => "(div (exp ?a) (exp ?b))" := +register_rule_to_tactic "exp_sub"; "(exp (sub ?a ?b))" => "(div (exp ?a) (exp ?b))" := apply Real.exp_sub; -register_rule_to_tactic "exp_sub-rev" ; "(div (exp ?a) (exp ?b))" => "(exp (sub ?a ?b))" := +register_rule_to_tactic "exp_sub-rev"; "(div (exp ?a) (exp ?b))" => "(exp (sub ?a ?b))" := apply (Real.exp_sub _ _).symm; -register_rule_to_tactic "exp_mul" ; "(exp (mul ?a ?b))" => "(pow (exp ?a) ?b)" := +register_rule_to_tactic "exp_mul"; "(exp (mul ?a ?b))" => "(pow (exp ?a) ?b)" := apply Real.exp_mul; -register_rule_to_tactic "exp_mul-rev" ; "(pow (exp ?a) ?b)" => "(exp (mul ?a ?b))" := +register_rule_to_tactic "exp_mul-rev"; "(pow (exp ?a) ?b)" => "(exp (mul ?a ?b))" := apply (Real.exp_mul _ _).symm; register_rule_to_tactic "exp_neg_eq_one_div"; "(exp (neg ?a))" => "(div 1 (exp ?a))" := apply Real.exp_neg_eq_one_div; -register_rule_to_tactic "exp_neg_eq_one_div-rev" ; "(div 1 (exp ?a))" => "(exp (neg ?a))" := +register_rule_to_tactic "exp_neg_eq_one_div-rev"; "(div 1 (exp ?a))" => "(exp (neg ?a))" := apply (Real.exp_neg_eq_one_div _).symm; -register_rule_to_tactic "log_mul" ; "(log (mul ?a ?b))" => "(add (log ?a) (log ?b))" := +register_rule_to_tactic "log_mul"; "(log (mul ?a ?b))" => "(add (log ?a) (log ?b))" := apply Real.log_mul (by positivity!) (by positivity!); register_rule_to_tactic "log_mul-rev"; "(add (log ?a) (log ?b))" => "(log (mul ?a ?b))" := apply (Real.log_mul (by positivity!) (by positivity!)).symm; -register_rule_to_tactic "log_div" ; "(log (div ?a ?b))" => "(sub (log ?a) (log ?b))" := +register_rule_to_tactic "log_div"; "(log (div ?a ?b))" => "(sub (log ?a) (log ?b))" := apply Real.log_div (by positivity!) (by positivity!); register_rule_to_tactic "log_div-rev"; "(sub (log ?a) (log ?b))" => "(log (div ?a ?b))" := apply (Real.log_div (by positivity!) (by positivity!)).symm; -register_rule_to_tactic "log_pow" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := +register_rule_to_tactic "log_pow"; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := apply (Real.log_rpow (by positivity!)); -register_rule_to_tactic "log_pow-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := +register_rule_to_tactic "log_pow-rev"; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := apply (Real.log_rpow (by positivity!) _).symm; -- Special type of rule as it only works if the exponent is a natural number. -register_rule_to_tactic "log_pow_nat" ; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := +register_rule_to_tactic "log_pow_nat"; "(log (pow ?a ?b))" => "(mul ?b (log ?a))" := (norm_cast; apply Real.log_pow); -register_rule_to_tactic "log_pow_nat-rev" ; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := +register_rule_to_tactic "log_pow_nat-rev"; "(mul ?b (log ?a))" => "(log (pow ?a ?b))" := (norm_cast; apply (Real.log_pow _ _).symm); -register_rule_to_tactic "exp_log" ; "(exp (log ?a))" => "?a" := +register_rule_to_tactic "exp_log"; "(exp (log ?a))" => "?a" := apply Real.exp_log (by positivity!); -register_rule_to_tactic "log_exp" ; "(log (exp ?a))" => "?a" := +register_rule_to_tactic "log_exp"; "(log (exp ?a))" => "?a" := apply Real.log_exp; /- Absolute value rules. -/ -register_rule_to_tactic "abs_nonneg" ; "(abs ?a)" => "?a" := +register_rule_to_tactic "abs_nonneg"; "(abs ?a)" => "?a" := apply abs_eq_self.mpr (by positivity!); -register_rule_to_tactic "abs_nonpos" ; "(abs ?a)" => "(neg ?a)" := +register_rule_to_tactic "abs_nonpos"; "(abs ?a)" => "(neg ?a)" := apply abs_eq_neg_self.mpr (by linarith); @@ -373,4 +388,6 @@ register_rule_to_tactic "norm2_fold"; "(sqrt (add (pow ?a 2) (pow ?b 2)))" => "( register_rule_to_tactic "norm2_unfold"; "(norm2 ?a ?b)" => "(sqrt (add (pow ?a 2) (pow ?b 2)))" := rfl; +end RealEqualityRewrites + end CvxLean diff --git a/CvxLean/Tactic/Solver/Mosek/Path.lean b/CvxLean/Tactic/Solver/Mosek/Path.lean deleted file mode 100644 index 6bdd0d20..00000000 --- a/CvxLean/Tactic/Solver/Mosek/Path.lean +++ /dev/null @@ -1,3 +0,0 @@ --- Change this to the location of the Mosek binary, e.g. --- /Users/alice/mosek/10.0/tools/platform/osxaarch64/bin -def mosekBinPath := "" \ No newline at end of file diff --git a/CvxLean/Test/BasicTactics/ChangeOfVariables.lean b/CvxLean/Test/BasicTactics/ChangeOfVariables.lean index 4d1c9adf..223ed0a6 100644 --- a/CvxLean/Test/BasicTactics/ChangeOfVariables.lean +++ b/CvxLean/Test/BasicTactics/ChangeOfVariables.lean @@ -2,6 +2,10 @@ import CvxLean.Command.Equivalence import CvxLean.Command.Reduction import CvxLean.Tactic.Basic.ChangeOfVariables +/-! +Testing changes of variables. +-/ + open CvxLean noncomputable section ChangeOfVariablesInstances diff --git a/CvxLean/Test/BasicTactics/Other.lean b/CvxLean/Test/BasicTactics/Other.lean index 3f424065..eafef94e 100644 --- a/CvxLean/Test/BasicTactics/Other.lean +++ b/CvxLean/Test/BasicTactics/Other.lean @@ -2,6 +2,10 @@ import CvxLean.Syntax.Minimization import CvxLean.Tactic.Basic.All import CvxLean.Command.Equivalence +/-! +Testing the tactics from `CvxLean/Tactic/Basic/All.lean`. +-/ + namespace BasicTacticTest noncomputable section diff --git a/CvxLean/Test/DCP/Basic.lean b/CvxLean/Test/DCP/Basic.lean index 3e0f51bc..8068581a 100644 --- a/CvxLean/Test/DCP/Basic.lean +++ b/CvxLean/Test/DCP/Basic.lean @@ -1,5 +1,9 @@ import CvxLean.Tactic.DCP.AtomLibrary.All +/-! +Some tests around the DCP transformation. One can check the outcome of `dcp` for each of them. +-/ + noncomputable section Basic open CvxLean Minimization Real @@ -146,12 +150,12 @@ def c1 : ℝ := 5 def d1 : ℝ := 6 def testWithConstants : Solution <| - optimization (x y : ℝ) - minimize (x) - subject to - c1 : 0 ≤ exp a1 - c2 : exp y ≤ log (exp a1) - c3 : 0 ≤ x := by + optimization (x y : ℝ) + minimize (x) + subject to + c1 : 0 ≤ exp a1 + c2 : exp y ≤ log (exp a1) + c3 : 0 ≤ x := by dcp sorry diff --git a/CvxLean/Test/DCP/Label.lean b/CvxLean/Test/DCP/Label.lean index 0684116e..705e625a 100644 --- a/CvxLean/Test/DCP/Label.lean +++ b/CvxLean/Test/DCP/Label.lean @@ -1,5 +1,13 @@ import CvxLean.Syntax.Label +/-! +Making labels explicitly, and printing them. + +### TODO + +* This does show what is expected because `pp.labels` had to be disabled for `doc-gen` to work. +-/ + def testLabel := {** 0 ** x **} #print testLabel diff --git a/CvxLean/Test/PreDCP/AlmostDGP.lean b/CvxLean/Test/PreDCP/AlmostDGP.lean index f6856cc0..02f160ac 100644 --- a/CvxLean/Test/PreDCP/AlmostDGP.lean +++ b/CvxLean/Test/PreDCP/AlmostDGP.lean @@ -5,6 +5,11 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Examples from geometric programming that do not quite follow the rules of DGP, but can be fixed +by rewriting. +-/ + namespace AlmostGP noncomputable section diff --git a/CvxLean/Test/PreDCP/DGP.lean b/CvxLean/Test/PreDCP/DGP.lean index a1485dce..4e936af4 100644 --- a/CvxLean/Test/PreDCP/DGP.lean +++ b/CvxLean/Test/PreDCP/DGP.lean @@ -5,6 +5,10 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Examples from geometric programming. +-/ + namespace GP noncomputable section diff --git a/CvxLean/Test/PreDCP/DQCP.lean b/CvxLean/Test/PreDCP/DQCP.lean index f65dc6b7..1fe03dc6 100644 --- a/CvxLean/Test/PreDCP/DQCP.lean +++ b/CvxLean/Test/PreDCP/DQCP.lean @@ -5,6 +5,9 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Examples from quasiconvex programming that can be directly rewritten into DCP form by `pre_dcp`. +-/ namespace DQCP diff --git a/CvxLean/Test/PreDCP/MainExample.lean b/CvxLean/Test/PreDCP/MainExample.lean index 02a014a8..0bd5d5bc 100644 --- a/CvxLean/Test/PreDCP/MainExample.lean +++ b/CvxLean/Test/PreDCP/MainExample.lean @@ -3,6 +3,10 @@ import CvxLean.Command.Equivalence import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.PreDCP.PreDCP +/-! +Example that we often use as a simple motivating example for automated pre-DCP transformations. +-/ + namespace MainExample noncomputable section @@ -34,9 +38,9 @@ solve q -- optimization (x : ℝ) (t.0 : ℝ) (t.1 : ℝ) -- minimize x -- subject to --- _ : posOrthCone (t.1 - t.0) +-- _ : nonnegOrthCone (t.1 - t.0) -- _ : expCone (-x) 1 t.0 --- _ : posOrthCone (x - 1 / 1000) +-- _ : nonnegOrthCone (x - 1 / 1000) -- _ : rotatedSoCone x (1 / 2) ![t.1] #eval q.value -- 0.426303 diff --git a/CvxLean/Test/PreDCP/Quiz.lean b/CvxLean/Test/PreDCP/Quiz.lean index fd86fcb7..99c01a2a 100644 --- a/CvxLean/Test/PreDCP/Quiz.lean +++ b/CvxLean/Test/PreDCP/Quiz.lean @@ -5,6 +5,10 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Examples harvested from the DCP quiz . +-/ + namespace Quiz noncomputable section diff --git a/CvxLean/Test/PreDCP/Stanford.lean b/CvxLean/Test/PreDCP/Stanford.lean index 8f947107..01dea00a 100644 --- a/CvxLean/Test/PreDCP/Stanford.lean +++ b/CvxLean/Test/PreDCP/Stanford.lean @@ -5,6 +5,11 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Problems from the additional exercises of the Stanford course on convex optimization: +. +-/ + namespace Stanford noncomputable section diff --git a/CvxLean/Test/PreDCP/Unit.lean b/CvxLean/Test/PreDCP/Unit.lean index 311460c5..c8e124fa 100644 --- a/CvxLean/Test/PreDCP/Unit.lean +++ b/CvxLean/Test/PreDCP/Unit.lean @@ -5,6 +5,12 @@ import CvxLean.Command.Util.TimeCmd import CvxLean.Tactic.Basic.ChangeOfVariables import CvxLean.Tactic.PreDCP.PreDCP +/-! +Unit tests for `pre_dcp`. The idea was to force certain rewrites to appear in the sequence. The set +of rewrite rules has been changing, so the name of the tests might not be accurate anymore. Still, +they are useful to check that all these problems can be rewritten into DCP. +-/ + namespace Unit noncomputable section @@ -76,7 +82,8 @@ def leSubIffAddLeConstrRev := subject to h : y + x ≤ x -time_cmd equivalence leSubIffAddLeConstrRevRed/leSubIffAddLeConstrRevAuto : leSubIffAddLeConstrRev := by +time_cmd equivalence leSubIffAddLeConstrRevRed/leSubIffAddLeConstrRevAuto : + leSubIffAddLeConstrRev := by pre_dcp #print leSubIffAddLeConstrRevAuto @@ -101,7 +108,8 @@ def subLeIffLeAddConstrRev := subject to h : x ≤ y + x -time_cmd equivalence subLeIffLeAddConstrRevRed/subLeIffLeAddConstrRevAuto : subLeIffLeAddConstrRev := by +time_cmd equivalence subLeIffLeAddConstrRevRed/subLeIffLeAddConstrRevAuto : + subLeIffLeAddConstrRev := by pre_dcp #print subLeIffLeAddConstrRevAuto diff --git a/CvxLean/Test/Solve/Problems/Exp.lean b/CvxLean/Test/Solve/Problems/Exp.lean index 1f6e4df2..193f55b5 100644 --- a/CvxLean/Test/Solve/Problems/Exp.lean +++ b/CvxLean/Test/Solve/Problems/Exp.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +Some problems involving the exponential function. +-/ + section Exp open CvxLean Minimization Real diff --git a/CvxLean/Test/Solve/Problems/Huber.py b/CvxLean/Test/Solve/Problems/Huber.py deleted file mode 100644 index 300be3dd..00000000 --- a/CvxLean/Test/Solve/Problems/Huber.py +++ /dev/null @@ -1,30 +0,0 @@ -import cvxpy as cp -import numpy as np - -def hub(x): - v = cp.Variable((1)) - w = cp.Variable((1)) - p = cp.Problem( - cp.Minimize(2 * v + w ** 2), [ - cp.abs(x) <= v + w, - v >= 0, - w <= 1, - ] - ) - p.solve() - return (v.value, w.value, 2 * v.value + w.value ** 2) - -xs = np.linspace(-3.0, 3.0, num=50) -hubs = [hub(x) for x in xs] -vs = np.array([x[0] for x in hubs]) -ws = np.array([x[1] for x in hubs]) -ys = np.array([x[2] for x in hubs]) -zs = np.array([x**2 for x in xs]) - - -import matplotlib.pyplot as plt -plt.plot(xs, ys) -plt.plot(xs, vs) -plt.plot(xs, ws) -plt.show() - diff --git a/CvxLean/Test/Solve/Problems/Linear.lean b/CvxLean/Test/Solve/Problems/Linear.lean index 7d549602..4085d323 100644 --- a/CvxLean/Test/Solve/Problems/Linear.lean +++ b/CvxLean/Test/Solve/Problems/Linear.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +Some linear programs. +-/ + section Linear open CvxLean Minimization Real @@ -110,7 +114,6 @@ noncomputable def linear5 := h₁ : A5.mulVec x ≤ b5 h₂ : 0 ≤ x -set_option trace.Meta.debug true in solve linear5 #print linear5.conicForm diff --git a/CvxLean/Test/Solve/Problems/Log.lean b/CvxLean/Test/Solve/Problems/Log.lean index d28c2ba9..ba147efa 100644 --- a/CvxLean/Test/Solve/Problems/Log.lean +++ b/CvxLean/Test/Solve/Problems/Log.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +Some problems involving the logarithm function. +-/ + section Log open CvxLean Minimization Real diff --git a/CvxLean/Test/Solve/Problems/LogDet.lean b/CvxLean/Test/Solve/Problems/LogDet.lean index 21587af2..627d5947 100644 --- a/CvxLean/Test/Solve/Problems/LogDet.lean +++ b/CvxLean/Test/Solve/Problems/LogDet.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +Testing the log-det atom. +-/ + section LogDet open CvxLean Minimization Real diff --git a/CvxLean/Test/Solve/Problems/SDP.lean b/CvxLean/Test/Solve/Problems/SDP.lean index d3f8f01f..37ba1e74 100644 --- a/CvxLean/Test/Solve/Problems/SDP.lean +++ b/CvxLean/Test/Solve/Problems/SDP.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +A semidefinite program. +-/ + namespace SDP open CvxLean Minimization Real diff --git a/CvxLean/Test/Solve/Problems/SO.lean b/CvxLean/Test/Solve/Problems/SO.lean index 6bc3c79a..7471e1a1 100644 --- a/CvxLean/Test/Solve/Problems/SO.lean +++ b/CvxLean/Test/Solve/Problems/SO.lean @@ -1,5 +1,9 @@ import CvxLean.Command.Solve +/-! +Some quadratic programs. +-/ + section SO open CvxLean Minimization Real diff --git a/CvxLean/Test/Solve/RealToFloat.lean b/CvxLean/Test/Solve/RealToFloat.lean index 8b7ad384..f6c4cfc0 100644 --- a/CvxLean/Test/Solve/RealToFloat.lean +++ b/CvxLean/Test/Solve/RealToFloat.lean @@ -2,6 +2,10 @@ import CvxLean.Lib.Minimization import CvxLean.Syntax.OptimizationParam import CvxLean.Command.Solve.Float.RealToFloatLibrary +/-! +Real-to-float tests. +-/ + section RealToFloat -- Simple tests. diff --git a/README.md b/README.md index 2333a84b..850b34bc 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ reduction red/prob : maximize x + y subject to h : (exp x) * (exp y) ≤ 10 := by - conv_constr => rw [←Real.exp_add] + conv_constr => rw [← Real.exp_add] #print prob -- def prob : Minimization (Real × Real) Real := diff --git a/scripts/check_style.sh b/scripts/check_style.sh deleted file mode 100755 index 7ea838e7..00000000 --- a/scripts/check_style.sh +++ /dev/null @@ -1 +0,0 @@ -python3 ./.lake/packages/mathlib/scripts/lint-style.py $1 | grep -vE "ERR_COP" diff --git a/scripts/style/README.md b/scripts/style/README.md new file mode 100644 index 00000000..1676b48c --- /dev/null +++ b/scripts/style/README.md @@ -0,0 +1,11 @@ +Scripts to check that files are styled according to `mathlib` standards (ignoring copyright). +They are supposed to be run from the root of the repository. + +For a single file: +``` +./scripts/style/check_style_file.sh +``` +For all files in a directory: +``` +./scripts/style/check_style_dir.sh +``` diff --git a/scripts/style/check_style_all.sh b/scripts/style/check_style_all.sh new file mode 100755 index 00000000..4fdaa29c --- /dev/null +++ b/scripts/style/check_style_all.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +./scripts/style/check_style_dir.sh CvxLean diff --git a/scripts/style/check_style_dir.sh b/scripts/style/check_style_dir.sh new file mode 100755 index 00000000..5250fe1e --- /dev/null +++ b/scripts/style/check_style_dir.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +DIR="$1" + +if [ ! -d "$DIR" ]; then + echo "Directory $directory does not exist." + exit 1 +fi + +find "$DIR" -type f -name "*.lean" -print0 | while IFS= read -r -d '' file; do + if [ -f "$file" ]; then + ./scripts/style/check_style_file.sh "$file" + fi +done diff --git a/scripts/style/check_style_file.sh b/scripts/style/check_style_file.sh new file mode 100755 index 00000000..b6544b36 --- /dev/null +++ b/scripts/style/check_style_file.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +errors=$(python3 ./.lake/packages/mathlib/scripts/lint-style.py $1 | grep -vE "ERR_COP") + +if [ -z "$errors" ]; then + exit 0 +else + echo "$errors" + exit 1 +fi