@@ -337,35 +337,54 @@ private def foundBVar (idx : Nat) : M Bool :=
337337private def saveBVar (idx : Nat) : M Unit := do
338338 modify fun s => { s with bvarsFound := s.bvarsFound.insert idx }
339339
340- private def getPatternFn? (pattern : Expr) : Option Expr :=
341- if !pattern.isApp && !pattern.isConst then
342- none
343- else match pattern.getAppFn with
344- | f@(.const declName _) => if isForbidden declName then none else some f
345- | f@(.fvar _) => some f
346- | _ => none
340+ inductive PatternArgKind where
341+ | /-- Argument is relevant for E-matching. -/
342+ relevant
343+ | /-- Instance implicit arguments are considered support and handled using `isDefEq`. -/
344+ instImplicit
345+ | /-- Proofs are ignored during E-matching. Lean is proof irrelevant. -/
346+ proof
347+ | /--
348+ Types and type formers are mostly ignored during E-matching, and processed using
349+ `isDefEq`. However, if the argument is of the form `C ..` where `C` is inductive type
350+ we process it as part of the pattern. Suppose we have `as bs : List α`, and a pattern
351+ candidate expression `as ++ bs`, i.e., `@HAppend.hAppend (List α) (List α) (List α) inst as bs`.
352+ If we completely ignore the types, the pattern will just be
353+ ```
354+ @HAppend.hAppend _ _ _ _ #1 #0
355+ ```
356+ This is not ideal because the E-matcher will try it in any goal that contains `++`,
357+ even if it does not even mention lists.
358+ -/
359+ typeFormer
360+ deriving Repr
361+
362+ def PatternArgKind.isSupport : PatternArgKind → Bool
363+ | .relevant => false
364+ | _ => true
347365
348366/--
349- Returns a bit-mask `mask` s.t. `mask[i]` is true if the corresponding argument is
367+ Returns an array `kinds` s.ts `kinds[i]` is the kind of the corresponding argument.
368+
350369- a type (that is not a proposition) or type former (which has forward dependencies) or
351370- a proof, or
352371- an instance implicit argument
353372
354- When `mask [i]`, we say the corresponding argument is a "support" argument.
373+ When `kinds [i].isSupport` is `true `, we say the corresponding argument is a "support" argument.
355374-/
356- def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool ) := do
375+ def getPatternArgKinds (f : Expr) (numArgs : Nat) : MetaM (Array PatternArgKind ) := do
357376 let pinfos := (← getFunInfoNArgs f numArgs).paramInfo
358377 forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
359378 xs.mapIdxM fun idx x => do
360379 if (← isProp x) then
361- return false
380+ return .relevant
362381 else if (← isProof x) then
363- return true
382+ return .proof
364383 else if (← isTypeFormer x) then
365384 if h : idx < pinfos.size then
366385 /-
367386 We originally wanted to ignore types and type formers in `grind` and treat them as supporting elements.
368- Thus, we would always return `true `. However, we changed our heuristic because of the following example:
387+ Thus, we wanted to always return `.typeFormer `. However, we changed our heuristic because of the following example:
369388 ```
370389 example {α} (f : α → Type) (a : α) (h : ∀ x, Nonempty (f x)) : Nonempty (f a) := by
371390 grind
@@ -374,59 +393,77 @@ def getPatternSupportMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
374393 a type or type former is considered a supporting element only if it has forward dependencies.
375394 Note that this is not the case for `Nonempty`.
376395 -/
377- return pinfos[idx].hasFwdDeps
396+ if pinfos[idx].hasFwdDeps then return .typeFormer else return .relevant
378397 else
379- return true
398+ return .typeFormer
399+ else if (← x.fvarId!.getDecl).binderInfo matches .instImplicit then
400+ return .instImplicit
380401 else
381- return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
402+ return .relevant
382403
383- private partial def go (pattern : Expr) : M Expr := do
404+ private def getPatternFn? (pattern : Expr) (inSupport : Bool) (argKind : PatternArgKind) : MetaM (Option Expr) := do
405+ if !pattern.isApp && !pattern.isConst then
406+ return none
407+ else match pattern.getAppFn with
408+ | f@(.const declName _) =>
409+ if isForbidden declName then
410+ return none
411+ if inSupport then
412+ if argKind matches .typeFormer | .relevant then
413+ if (← isInductive declName) then
414+ return some f
415+ return none
416+ return some f
417+ | f@(.fvar _) =>
418+ if inSupport then return none else return some f
419+ | _ =>
420+ return none
421+
422+ private partial def go (pattern : Expr) (inSupport : Bool) : M Expr := do
384423 if let some (e, k) := isOffsetPattern? pattern then
385- let e ← goArg e (isSupport := false )
424+ let e ← goArg e inSupport .relevant
386425 if e == dontCare then
387426 return dontCare
388427 else
389428 return mkOffsetPattern e k
390- let some f := getPatternFn? pattern
429+ let some f ← getPatternFn? pattern inSupport .relevant
391430 | throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
392431 assert! f.isConst || f.isFVar
393432 unless f.isConstOf ``Grind.eqBwdPattern do
394- saveSymbol f.toHeadIndex
433+ saveSymbol f.toHeadIndex
395434 let mut args := pattern.getAppArgs.toVector
396- let supportMask ← getPatternSupportMask f args.size
435+ let patternArgKinds ← getPatternArgKinds f args.size
397436 for h : i in [:args.size] do
398437 let arg := args[i]
399- let isSupport := supportMask [i]?.getD false
400- args := args.set i (← goArg arg isSupport)
438+ let argKind := patternArgKinds [i]?.getD .relevant
439+ args := args.set i (← goArg arg (inSupport || argKind. isSupport) argKind )
401440 return mkAppN f args.toArray
402441where
403- goArg (arg : Expr) (isSupport : Bool) : M Expr := do
442+ goArg (arg : Expr) (inSupport : Bool) (argKind : PatternArgKind ) : M Expr := do
404443 if !arg.hasLooseBVars then
405444 if arg.hasMVar then
406445 pure dontCare
407446 else
408447 pure <| mkGroundPattern arg
409448 else match arg with
410449 | .bvar idx =>
411- if isSupport && (← foundBVar idx) then
450+ if inSupport && (← foundBVar idx) then
412451 pure dontCare
413452 else
414453 saveBVar idx
415454 pure arg
416455 | _ =>
417- if isSupport then
418- pure dontCare
419- else if let some _ := getPatternFn? arg then
420- go arg
456+ if let some _ ← getPatternFn? arg inSupport argKind then
457+ go arg inSupport
421458 else
422459 pure dontCare
423460
424461def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do
425- let (patterns, s) ← patterns.mapM go |>.run {}
462+ let (patterns, s) ← patterns.mapM (go (inSupport := false )) |>.run {}
426463 return (patterns, s.symbols.toList, s.bvarsFound)
427464
428465def normalizePattern (e : Expr) : M Expr := do
429- go e
466+ go e (inSupport := false )
430467
431468end NormalizePattern
432469
@@ -668,11 +705,11 @@ private def isPatternFnCandidate (f : Expr) : CollectorM Bool := do
668705 | _ => return false
669706
670707private def addNewPattern (p : Expr) : CollectorM Unit := do
671- trace[grind.ematch.pattern.search ] "found pattern: {ppPattern p}"
708+ trace[grind.debug. ematch.pattern] "found pattern: {ppPattern p}"
672709 let bvarsFound := (← getThe NormalizePattern.State).bvarsFound
673710 let done := (← checkCoverage (← read).proof (← read).xs.size bvarsFound) matches .ok
674711 if done then
675- trace[grind.ematch.pattern.search ] "found full coverage"
712+ trace[grind.debug. ematch.pattern] "found full coverage"
676713 modify fun s => { s with patterns := s.patterns.push p, done }
677714
678715/-- Collect the pattern (i.e., de Bruijn) variables in the given pattern. -/
@@ -696,10 +733,11 @@ Returns `true` if pattern `p` contains a child `c` such that
6967333- `c` is not an offset pattern.
6977344- `c` is not a bound variable.
698735-/
699- private def hasChildWithSameNewBVars (p : Expr) (supportMask : Array Bool) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do
736+ private def hasChildWithSameNewBVars (p : Expr)
737+ (argKinds : Array NormalizePattern.PatternArgKind) (alreadyFound : Std.HashSet Nat) : CoreM Bool := do
700738 let s := diff (collectPatternBVars p) alreadyFound
701- for arg in p.getAppArgs, support in supportMask do
702- unless support do
739+ for arg in p.getAppArgs, argKind in argKinds do
740+ unless argKind.isSupport do
703741 unless arg.isBVar do
704742 unless isOffsetPattern? arg |>.isSome do
705743 let sArg := diff (collectPatternBVars arg) alreadyFound
@@ -711,31 +749,33 @@ private partial def collect (e : Expr) : CollectorM Unit := do
711749 if (← get).done then return ()
712750 match e with
713751 | .app .. =>
752+ trace[grind.debug.ematch.pattern] "collect: {e}"
714753 let f := e.getAppFn
715- let supportMask ← NormalizePattern.getPatternSupportMask f e.getAppNumArgs
754+ let argKinds ← NormalizePattern.getPatternArgKinds f e.getAppNumArgs
716755 if (← isPatternFnCandidate f) then
717756 let saved ← getThe NormalizePattern.State
718757 try
719- trace[grind.ematch.pattern.search ] "candidate: {e}"
758+ trace[grind.debug. ematch.pattern] "candidate: {e}"
720759 let p := e.abstract (← read).xs
721760 unless p.hasLooseBVars do
722- trace[grind.ematch.pattern.search ] "skip, does not contain pattern variables"
761+ trace[grind.debug. ematch.pattern] "skip, does not contain pattern variables"
723762 return ()
724763 let p ← NormalizePattern.normalizePattern p
725764 if saved.bvarsFound.size < (← getThe NormalizePattern.State).bvarsFound.size then
726- unless (← hasChildWithSameNewBVars p supportMask saved.bvarsFound) do
765+ unless (← hasChildWithSameNewBVars p argKinds saved.bvarsFound) do
727766 addNewPattern p
728767 return ()
729- trace[grind.ematch.pattern.search ] "skip, no new variables covered"
768+ trace[grind.debug. ematch.pattern] "skip, no new variables covered"
730769 -- restore state and continue search
731770 set saved
732- catch _ =>
733- trace[grind.ematch.pattern.search ] "skip, exception during normalization"
771+ catch ex =>
772+ trace[grind.debug. ematch.pattern] "skip, exception during normalization{indentD ex.toMessageData} "
734773 -- restore state and continue search
735774 set saved
736775 let args := e.getAppArgs
737- for arg in args, support in supportMask do
738- unless support do
776+ for arg in args, argKind in argKinds do
777+ trace[grind.debug.ematch.pattern] "arg: {arg}, support: {argKind.isSupport}"
778+ unless argKind.isSupport do
739779 collect arg
740780 | .forallE _ d b _ =>
741781 if (← pure e.isArrow <&&> isProp d <&&> isProp b) then
@@ -746,6 +786,7 @@ private partial def collect (e : Expr) : CollectorM Unit := do
746786private def collectPatterns? (proof : Expr) (xs : Array Expr) (searchPlaces : Array Expr) : MetaM (Option (List Expr × List HeadIndex)) := do
747787 let go : CollectorM (Option (List Expr)) := do
748788 for place in searchPlaces do
789+ trace[grind.debug.ematch.pattern] "place: {place}"
749790 let place ← preprocessPattern place
750791 collect place
751792 if (← get).done then
@@ -782,8 +823,8 @@ where
782823 return some e
783824 else
784825 let args := e.getAppArgs
785- for arg in args, flag in (← NormalizePattern.getPatternSupportMask f args.size) do
786- unless flag do
826+ for arg in args, argKind in (← NormalizePattern.getPatternArgKinds f args.size) do
827+ unless argKind.isSupport do
787828 if let some r ← visit? arg then
788829 return r
789830 return none
0 commit comments