Skip to content

Commit

Permalink
feat: progress? to say which theorem has been used (#277)
Browse files Browse the repository at this point in the history
Like `simp?` and similar, `progress?` will provide what theorem was used
to make the goal progress.

Signed-off-by: Ryan Lahfa <[email protected]>
Co-authored-by: Ryan Lahfa <[email protected]>
Co-authored-by: Son HO <[email protected]>
  • Loading branch information
3 people committed Jul 5, 2024
1 parent e7f5c00 commit 71f2c2a
Showing 1 changed file with 51 additions and 10 deletions.
61 changes: 51 additions & 10 deletions backends/lean/Base/Progress/Progress.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@ inductive TheoremOrLocal where
| Theorem (thName : Name)
| Local (asm : LocalDecl)

structure Stats where
usedTheorem : TheoremOrLocal

instance : ToMessageData TheoremOrLocal where
toMessageData := λ x => match x with | .Theorem thName => m!"{thName}" | .Local asm => m!"{asm.userName}"

instance : Coe TheoremOrLocal Name where
coe := λ x => match x with | .Theorem thName => thName | .Local asm => asm.userName

/- Type to propagate the errors of `progressWith`.
We need this because we use the exceptions to backtrack, when trying to
use the assumptions for instance. When there is actually an error we want
Expand Down Expand Up @@ -234,7 +240,7 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost :

-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal)
(ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do
(ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM TheoremOrLocal := do
withMainContext do
-- Retrieve the goal
let mgoal ← Tactic.getMainGoal
Expand Down Expand Up @@ -262,7 +268,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
match withTh with
| some th => do
match ← progressWith fExpr th keep ids splitPost asmTac with
| .Ok => return ()
| .Ok => return th
| .Error msg => throwError msg
| none =>
-- Try all the assumptions one by one and if it fails try to lookup a theorem.
Expand All @@ -272,7 +278,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
trace[Progress] "Trying assumption: {decl.userName} : {decl.type}"
let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue
match res with
| .Ok => return ()
| .Ok => return (.Local decl)
| .Error msg => throwError msg
-- It failed: lookup the pspec theorems which match the expression *only
-- if the function is a constant*
Expand All @@ -293,7 +299,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
pure (thNames.map fun th => TheoremOrLocal.Theorem th)
-- Try the theorems one by one
for pspec in pspecs do
if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return ()
if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return pspec
else pure ()
-- It failed: try to use the recursive assumptions
trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption"
Expand All @@ -306,14 +312,14 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue
match res with
| .Ok => return ()
| .Ok => return (.Local decl)
| .Error msg => throwError msg
-- Nothing worked: failed
throwError "Progress failed"

syntax progressArgs := ("keep" (ident <|> "_"))? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")?

def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Stats := do
let args := args.raw
-- Process the arguments to retrieve the identifiers to use
trace[Progress] "Progress arguments: {args}"
Expand Down Expand Up @@ -370,14 +376,26 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
Arith.scalarTac false
else
throwError "Not a linear arithmetic goal"
progressAsmsOrLookupTheorem keep withArg ids splitPost (
let usedTheorem ← progressAsmsOrLookupTheorem keep withArg ids splitPost (
withMainContext do
trace[Progress] "trying to solve assumption: {← getMainGoal}"
firstTac [assumptionTac, scalarTac])
trace[Diverge] "Progress done"
trace[Progress] "Progress done"
return {
usedTheorem
}

elab (name := progress) "progress" args:progressArgs : tactic =>
evalProgress args *> return ()

elab "progress" args:progressArgs : tactic =>
evalProgress args
elab tk:"progress?" args:progressArgs : tactic => do
let stats ← evalProgress args
let mut stxArgs := args.raw
if stxArgs[1].isNone then
let withArg := mkNullNode #[mkAtom "with", mkIdent stats.usedTheorem]
stxArgs := stxArgs.setArg 1 withArg
let tac := mkNode `Progress.progress #[mkAtom "progress", stxArgs]
Meta.Tactic.TryThis.addSuggestion tk tac (origSpan? := ← getRef)

namespace Test
open Primitives Result
Expand All @@ -386,6 +404,8 @@ namespace Test
-- set_option trace.Progress true
-- set_option pp.rawOnError true

set_option says.verify true

-- The following commands display the databases of theorems
-- #eval showStoredPSpec
open alloc.vec
Expand All @@ -397,6 +417,13 @@ namespace Test
progress keep _ as ⟨ z, h1 .. ⟩
simp [*, h1]

example {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val + y.val)
(hmax : x.val + y.val ≤ Scalar.max ty) :
∃ z, x + y = ok z ∧ z.val = x.val + y.val := by
progress? keep _ as ⟨ z, h1 .. ⟩ says progress keep _ with Primitives.Scalar.add_spec as ⟨ z, h1 .. ⟩
simp [*, h1]

example {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val + y.val)
(hmax : x.val + y.val ≤ Scalar.max ty) :
Expand Down Expand Up @@ -432,6 +459,14 @@ namespace Test
progress
simp [*]


example {α : Type} (v: Vec α) (i: Usize) (x : α)
(hbounds : i.val < v.length) :
∃ nv, v.update_usize i x = ok nv ∧
nv.val = v.val.update i.val x := by
progress? says progress with Primitives.alloc.vec.Vec.update_usize_spec
simp [*]

/- Checking that progress can handle nested blocks -/
example {α : Type} (v: Vec α) (i: Usize) (x : α)
(hbounds : i.val < v.length) :
Expand All @@ -454,6 +489,12 @@ namespace Test
f x = ok () := by
progress


example {x : U32}
(f : U32 → Result Unit) (h : ∀ x, f x = .ok ()) :
f x = ok () := by
progress? says progress with h

/- The use of `right` introduces a meta-variable in the goal, that we
need to instantiate (otherwise `progress` gets stuck) -/
example {ty} {x y : Scalar ty}
Expand Down

0 comments on commit 71f2c2a

Please sign in to comment.