Skip to content

Commit af4c693

Browse files
authored
feat: improve E-matching pattern inference in grind (leanprover#8196)
This PR improves the E-matching pattern inference procedure in `grind`. Consider the following theorem: ```lean @[grind →] theorem eq_empty_of_append_eq_empty {xs ys : Array α} (h : xs ++ ys = #[]) : xs = #[] ∧ ys = #[] := append_eq_empty_iff.mp h ``` Before this PR, `grind` inferred the following pattern: ```lean @HAppend.hAppend _ _ _ _ leanprover#2 #1 ``` Note that this pattern would match any `++` application, even if it had nothing to do with arrays. With this PR, the inferred pattern becomes: ```lean @HAppend.hAppend (Array leanprover#3) (Array _) (Array _) _ leanprover#2 #1 ``` With the new pattern, the theorem will not be considered by `grind` for goals that do not involve `Array`s.
1 parent 9277555 commit af4c693

File tree

13 files changed

+164
-101
lines changed

13 files changed

+164
-101
lines changed

src/Lean/Meta/Tactic/Grind.lean

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ builtin_initialize registerTraceClass `grind.eqc
4242
builtin_initialize registerTraceClass `grind.internalize
4343
builtin_initialize registerTraceClass `grind.ematch
4444
builtin_initialize registerTraceClass `grind.ematch.pattern
45-
builtin_initialize registerTraceClass `grind.ematch.pattern.search
4645
builtin_initialize registerTraceClass `grind.ematch.instance
4746
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
4847
builtin_initialize registerTraceClass `grind.eqResolution
@@ -71,6 +70,7 @@ builtin_initialize registerTraceClass `grind.debug.final
7170
builtin_initialize registerTraceClass `grind.debug.forallPropagator
7271
builtin_initialize registerTraceClass `grind.debug.split
7372
builtin_initialize registerTraceClass `grind.debug.canon
73+
builtin_initialize registerTraceClass `grind.debug.ematch.activate
7474
builtin_initialize registerTraceClass `grind.debug.ematch.pattern
7575
builtin_initialize registerTraceClass `grind.debug.beta
7676
builtin_initialize registerTraceClass `grind.debug.internalize
@@ -81,7 +81,6 @@ builtin_initialize registerTraceClass `grind.debug.mbtc
8181
builtin_initialize registerTraceClass `grind.debug.ematch
8282
builtin_initialize registerTraceClass `grind.debug.proveEq
8383
builtin_initialize registerTraceClass `grind.debug.pushNewFact
84-
builtin_initialize registerTraceClass `grind.debug.ematch.activate
8584
builtin_initialize registerTraceClass `grind.debug.appMap
8685
builtin_initialize registerTraceClass `grind.debug.ext
8786

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do w
300300
let report : M Unit := do
301301
reportIssue! "type error constructing proof for {← thm.origin.pp}\nwhen assigning metavariable {mvars[i]} with {indentExpr v}\n{← mkHasTypeButIsExpectedMsg vType mvarIdType}"
302302
unless (← withDefault <| isDefEq mvarIdType vType) do
303-
let some heq ← proveEq? vType mvarIdType
303+
let some heq ← withoutReportingMVarIssues <| proveEq? vType mvarIdType
304304
| report
305305
return ()
306306
/-

src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -337,35 +337,54 @@ private def foundBVar (idx : Nat) : M Bool :=
337337
private def saveBVar (idx : Nat) : M Unit := do
338338
modify fun s => { s with bvarsFound := s.bvarsFound.insert idx }
339339

340-
private def getPatternFn? (pattern : Expr) : Option Expr :=
341-
if !pattern.isApp && !pattern.isConst then
342-
none
343-
else match pattern.getAppFn with
344-
| f@(.const declName _) => if isForbidden declName then none else some f
345-
| f@(.fvar _) => some f
346-
| _ => none
340+
inductive PatternArgKind where
341+
| /-- Argument is relevant for E-matching. -/
342+
relevant
343+
| /-- Instance implicit arguments are considered support and handled using `isDefEq`. -/
344+
instImplicit
345+
| /-- Proofs are ignored during E-matching. Lean is proof irrelevant. -/
346+
proof
347+
| /--
348+
Types and type formers are mostly ignored during E-matching, and processed using
349+
`isDefEq`. However, if the argument is of the form `C ..` where `C` is inductive type
350+
we process it as part of the pattern. Suppose we have `as bs : List α`, and a pattern
351+
candidate expression `as ++ bs`, i.e., `@HAppend.hAppend (List α) (List α) (List α) inst as bs`.
352+
If we completely ignore the types, the pattern will just be
353+
```
354+
@HAppend.hAppend _ _ _ _ #1 #0
355+
```
356+
This is not ideal because the E-matcher will try it in any goal that contains `++`,
357+
even if it does not even mention lists.
358+
-/
359+
typeFormer
360+
deriving Repr
361+
362+
def PatternArgKind.isSupport : PatternArgKind → Bool
363+
| .relevant => false
364+
| _ => true
347365

348366
/--
349-
Returns a bit-mask `mask` s.t. `mask[i]` is true if the corresponding argument is
367+
Returns an array `kinds` s.ts `kinds[i]` is the kind of the corresponding argument.
368+
350369
- a type (that is not a proposition) or type former (which has forward dependencies) or
351370
- a proof, or
352371
- an instance implicit argument
353372
354-
When `mask[i]`, we say the corresponding argument is a "support" argument.
373+
When `kinds[i].isSupport` is `true`, we say the corresponding argument is a "support" argument.
355374
-/
356-
def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
375+
def getPatternArgKinds (f : Expr) (numArgs : Nat) : MetaM (Array PatternArgKind) := do
357376
let pinfos := (← getFunInfoNArgs f numArgs).paramInfo
358377
forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
359378
xs.mapIdxM fun idx x => do
360379
if (← isProp x) then
361-
return false
380+
return .relevant
362381
else if (← isProof x) then
363-
return true
382+
return .proof
364383
else if (← isTypeFormer x) then
365384
if h : idx < pinfos.size then
366385
/-
367386
We originally wanted to ignore types and type formers in `grind` and treat them as supporting elements.
368-
Thus, we would always return `true`. However, we changed our heuristic because of the following example:
387+
Thus, we wanted to always return `.typeFormer`. However, we changed our heuristic because of the following example:
369388
```
370389
example {α} (f : α → Type) (a : α) (h : ∀ x, Nonempty (f x)) : Nonempty (f a) := by
371390
grind
@@ -374,59 +393,77 @@ def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
374393
a type or type former is considered a supporting element only if it has forward dependencies.
375394
Note that this is not the case for `Nonempty`.
376395
-/
377-
return pinfos[idx].hasFwdDeps
396+
if pinfos[idx].hasFwdDeps then return .typeFormer else return .relevant
378397
else
379-
return true
398+
return .typeFormer
399+
else if (← x.fvarId!.getDecl).binderInfo matches .instImplicit then
400+
return .instImplicit
380401
else
381-
return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
402+
return .relevant
382403

383-
private partial def go (pattern : Expr) : M Expr := do
404+
private def getPatternFn? (pattern : Expr) (inSupport : Bool) (argKind : PatternArgKind) : MetaM (Option Expr) := do
405+
if !pattern.isApp && !pattern.isConst then
406+
return none
407+
else match pattern.getAppFn with
408+
| f@(.const declName _) =>
409+
if isForbidden declName then
410+
return none
411+
if inSupport then
412+
if argKind matches .typeFormer | .relevant then
413+
if (← isInductive declName) then
414+
return some f
415+
return none
416+
return some f
417+
| f@(.fvar _) =>
418+
if inSupport then return none else return some f
419+
| _ =>
420+
return none
421+
422+
private partial def go (pattern : Expr) (inSupport : Bool) : M Expr := do
384423
if let some (e, k) := isOffsetPattern? pattern then
385-
let e ← goArg e (isSupport := false)
424+
let e ← goArg e inSupport .relevant
386425
if e == dontCare then
387426
return dontCare
388427
else
389428
return mkOffsetPattern e k
390-
let some f := getPatternFn? pattern
429+
let some f getPatternFn? pattern inSupport .relevant
391430
| throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
392431
assert! f.isConst || f.isFVar
393432
unless f.isConstOf ``Grind.eqBwdPattern do
394-
saveSymbol f.toHeadIndex
433+
saveSymbol f.toHeadIndex
395434
let mut args := pattern.getAppArgs.toVector
396-
let supportMaskgetPatternSupportMask f args.size
435+
let patternArgKindsgetPatternArgKinds f args.size
397436
for h : i in [:args.size] do
398437
let arg := args[i]
399-
let isSupport := supportMask[i]?.getD false
400-
args := args.set i (← goArg arg isSupport)
438+
let argKind := patternArgKinds[i]?.getD .relevant
439+
args := args.set i (← goArg arg (inSupport || argKind.isSupport) argKind)
401440
return mkAppN f args.toArray
402441
where
403-
goArg (arg : Expr) (isSupport : Bool) : M Expr := do
442+
goArg (arg : Expr) (inSupport : Bool) (argKind : PatternArgKind) : M Expr := do
404443
if !arg.hasLooseBVars then
405444
if arg.hasMVar then
406445
pure dontCare
407446
else
408447
pure <| mkGroundPattern arg
409448
else match arg with
410449
| .bvar idx =>
411-
if isSupport && (← foundBVar idx) then
450+
if inSupport && (← foundBVar idx) then
412451
pure dontCare
413452
else
414453
saveBVar idx
415454
pure arg
416455
| _ =>
417-
if isSupport then
418-
pure dontCare
419-
else if let some _ := getPatternFn? arg then
420-
go arg
456+
if let some _ ← getPatternFn? arg inSupport argKind then
457+
go arg inSupport
421458
else
422459
pure dontCare
423460

424461
def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do
425-
let (patterns, s) ← patterns.mapM go |>.run {}
462+
let (patterns, s) ← patterns.mapM (go (inSupport := false)) |>.run {}
426463
return (patterns, s.symbols.toList, s.bvarsFound)
427464

428465
def normalizePattern (e : Expr) : M Expr := do
429-
go e
466+
go e (inSupport := false)
430467

431468
end NormalizePattern
432469

@@ -668,11 +705,11 @@ private def isPatternFnCandidate (f : Expr) : CollectorM Bool := do
668705
| _ => return false
669706

670707
private def addNewPattern (p : Expr) : CollectorM Unit := do
671-
trace[grind.ematch.pattern.search] "found pattern: {ppPattern p}"
708+
trace[grind.debug.ematch.pattern] "found pattern: {ppPattern p}"
672709
let bvarsFound := (← getThe NormalizePattern.State).bvarsFound
673710
let done := (← checkCoverage (← read).proof (← read).xs.size bvarsFound) matches .ok
674711
if done then
675-
trace[grind.ematch.pattern.search] "found full coverage"
712+
trace[grind.debug.ematch.pattern] "found full coverage"
676713
modify fun s => { s with patterns := s.patterns.push p, done }
677714

678715
/-- Collect the pattern (i.e., de Bruijn) variables in the given pattern. -/
@@ -696,10 +733,11 @@ Returns `true` if pattern `p` contains a child `c` such that
696733
3- `c` is not an offset pattern.
697734
4- `c` is not a bound variable.
698735
-/
699-
private def hasChildWithSameNewBVars (p : Expr) (supportMask : Array Bool) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do
736+
private def hasChildWithSameNewBVars (p : Expr)
737+
(argKinds : Array NormalizePattern.PatternArgKind) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do
700738
let s := diff (collectPatternBVars p) alreadyFound
701-
for arg in p.getAppArgs, support in supportMask do
702-
unless support do
739+
for arg in p.getAppArgs, argKind in argKinds do
740+
unless argKind.isSupport do
703741
unless arg.isBVar do
704742
unless isOffsetPattern? arg |>.isSome do
705743
let sArg := diff (collectPatternBVars arg) alreadyFound
@@ -711,31 +749,33 @@ private partial def collect (e : Expr) : CollectorM Unit := do
711749
if (← get).done then return ()
712750
match e with
713751
| .app .. =>
752+
trace[grind.debug.ematch.pattern] "collect: {e}"
714753
let f := e.getAppFn
715-
let supportMask ← NormalizePattern.getPatternSupportMask f e.getAppNumArgs
754+
let argKinds ← NormalizePattern.getPatternArgKinds f e.getAppNumArgs
716755
if (← isPatternFnCandidate f) then
717756
let saved ← getThe NormalizePattern.State
718757
try
719-
trace[grind.ematch.pattern.search] "candidate: {e}"
758+
trace[grind.debug.ematch.pattern] "candidate: {e}"
720759
let p := e.abstract (← read).xs
721760
unless p.hasLooseBVars do
722-
trace[grind.ematch.pattern.search] "skip, does not contain pattern variables"
761+
trace[grind.debug.ematch.pattern] "skip, does not contain pattern variables"
723762
return ()
724763
let p ← NormalizePattern.normalizePattern p
725764
if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then
726-
unless (← hasChildWithSameNewBVars p supportMask saved.bvarsFound) do
765+
unless (← hasChildWithSameNewBVars p argKinds saved.bvarsFound) do
727766
addNewPattern p
728767
return ()
729-
trace[grind.ematch.pattern.search] "skip, no new variables covered"
768+
trace[grind.debug.ematch.pattern] "skip, no new variables covered"
730769
-- restore state and continue search
731770
set saved
732-
catch _ =>
733-
trace[grind.ematch.pattern.search] "skip, exception during normalization"
771+
catch ex =>
772+
trace[grind.debug.ematch.pattern] "skip, exception during normalization{indentD ex.toMessageData}"
734773
-- restore state and continue search
735774
set saved
736775
let args := e.getAppArgs
737-
for arg in args, support in supportMask do
738-
unless support do
776+
for arg in args, argKind in argKinds do
777+
trace[grind.debug.ematch.pattern] "arg: {arg}, support: {argKind.isSupport}"
778+
unless argKind.isSupport do
739779
collect arg
740780
| .forallE _ d b _ =>
741781
if (← pure e.isArrow <&&> isProp d <&&> isProp b) then
@@ -746,6 +786,7 @@ private partial def collect (e : Expr) : CollectorM Unit := do
746786
private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option (List Expr × List HeadIndex)) := do
747787
let go : CollectorM (Option (List Expr)) := do
748788
for place in searchPlaces do
789+
trace[grind.debug.ematch.pattern] "place: {place}"
749790
let place ← preprocessPattern place
750791
collect place
751792
if (← get).done then
@@ -782,8 +823,8 @@ where
782823
return some e
783824
else
784825
let args := e.getAppArgs
785-
for arg in args, flag in (← NormalizePattern.getPatternSupportMask f args.size) do
786-
unless flag do
826+
for arg in args, argKind in (← NormalizePattern.getPatternArgKinds f args.size) do
827+
unless argKind.isSupport do
787828
if let some r ← visit? arg then
788829
return r
789830
return none

src/Lean/Meta/Tactic/Grind/Internalize.lean

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
313313
mkENode e generation
314314
activateTheoremPatterns declName generation
315315
| .mvar .. =>
316-
reportIssue! "unexpected metavariable during internalization{indentExpr e}\n`grind` is not supposed to be used in goals containing metavariables."
316+
if (← reportMVarInternalization) then
317+
reportIssue! "unexpected metavariable during internalization{indentExpr e}\n`grind` is not supposed to be used in goals containing metavariables."
317318
mkENode' e generation
318319
| .mdata .. =>
319320
reportIssue! "unexpected metadata found during internalization{indentExpr e}\n`grind` uses a pre-processing step that eliminates metadata"

src/Lean/Meta/Tactic/Grind/Types.lean

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ structure Context where
6868
when performing lookahead.
6969
-/
7070
cheapCases : Bool := false
71+
/-
72+
If `reportMVarIssue` is `true`, `grind` reports an issue when internalizing metavariables.
73+
The default value is `true`. We set it to `false` when invoking `proveEq` from the `instantiateTheorem`
74+
at in the E-matching module. There, the terms may contain metavariables, and we don't want to bother
75+
user with "false-alarms". If the instantiation fails, we produce a more informative issue anyways.
76+
-/
77+
reportMVarIssue : Bool := true
7178

7279
/-- Key for the congruence theorem cache. -/
7380
structure CongrTheoremCacheKey where
@@ -147,6 +154,16 @@ instance : Nonempty MethodsRef := MethodsRefPointed.property
147154

148155
abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM
149156

157+
@[inline] def mapGrindM [MonadControlT GrindM m] [Monad m] (f : {α : Type} → GrindM α → GrindM α) {α} (x : m α) : m α :=
158+
controlAt GrindM fun runInBase => f <| runInBase x
159+
160+
/--
161+
`withoutReportingMVarIssues x` executes `x` without reporting metavariables found during internalization.
162+
See comment at `Grind.Context.reportMVarIssue` for additional details.
163+
-/
164+
def withoutReportingMVarIssues [MonadControlT GrindM m] [Monad m] : m α → m α :=
165+
mapGrindM <| withTheReader Grind.Context fun ctx => { ctx with reportMVarIssue := false }
166+
150167
/-- Returns the user-defined configuration options -/
151168
def getConfig : GrindM Grind.Config :=
152169
return (← readThe Context).config
@@ -177,6 +194,9 @@ def getMainDeclName : GrindM Name :=
177194
def cheapCasesOnly : GrindM Bool :=
178195
return (← readThe Context).cheapCases
179196

197+
def reportMVarInternalization : GrindM Bool :=
198+
return (← readThe Context).reportMVarIssue
199+
180200
def saveEMatchTheorem (thm : EMatchTheorem) : GrindM Unit := do
181201
if (← getConfig).trace then
182202
modify fun s => { s with trace.thms := s.trace.thms.insert { origin := thm.origin, kind := thm.kind } }

tests/lean/grind/hashmap_list.lean

Lines changed: 0 additions & 18 deletions
This file was deleted.

tests/lean/grind/mvar.lean

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)