diff --git a/src/Lean/Elab/PreDefinition/Main.lean b/src/Lean/Elab/PreDefinition/Main.lean index 471ad586ea70..562c42508f52 100644 --- a/src/Lean/Elab/PreDefinition/Main.lean +++ b/src/Lean/Elab/PreDefinition/Main.lean @@ -18,15 +18,14 @@ private def addAndCompilePartial (preDefs : Array PreDefinition) (useSorry := fa for preDef in preDefs do trace[Elab.definition] "processing {preDef.declName}" let all := preDefs.toList.map (·.declName) - forallTelescope preDef.type fun xs type => do - let value ← if useSorry then - mkLambdaFVars xs (← mkSorry type (synthetic := true)) - else - liftM <| mkInhabitantFor preDef.declName xs type - addNonRec { preDef with - kind := DefKind.«opaque» - value - } (all := all) + let value ← if useSorry then + mkSorry preDef.type (synthetic := true) + else + liftM <| mkInhabitantFor preDef.declName preDef.type + addNonRec { preDef with + kind := DefKind.«opaque» + value + } (all := all) addAndCompilePartialRec preDefs private def isNonRecursive (preDef : PreDefinition) : Bool := diff --git a/src/Lean/Elab/PreDefinition/MkInhabitant.lean b/src/Lean/Elab/PreDefinition/MkInhabitant.lean index 4006a9d28493..21179018ae20 100644 --- a/src/Lean/Elab/PreDefinition/MkInhabitant.lean +++ b/src/Lean/Elab/PreDefinition/MkInhabitant.lean @@ -9,64 +9,118 @@ import Lean.PrettyPrinter namespace Lean.Elab open Meta -private def withInhabitedInstances (xs : Array Expr) (k : Array Expr → MetaM α) : MetaM α := do - let rec go (i : Nat) (insts : Array Expr) : MetaM α := do +private structure InhabitantCache (useOfNonempty : Bool) where + /-- Map from types to inhabitants. -/ + inhabitants : ExprMap Expr := {} + /-- Local instances to eliminate with `mkLetFVars`. -/ + localInstances : Array Expr := #[] + /-- Worklist of expressions to try to get more inhabitants out of. + The types of these are functions `α → β`, and in `resolve` we try to inhabit `α`. -/ + unresolved : Array Expr := #[] + +private def InhabitantCache.withInhabitant (cache : InhabitantCache useOfNonempty) (x : Expr) + (k : InhabitantCache useOfNonempty → MetaM α) : MetaM α := do + let xTy ← whnfCore (← inferType x) + if cache.inhabitants.contains xTy then + k cache + else + -- Add `Inhabited` instance to help find inhabitants through typeclass inference later. + let u ← getLevel xTy + let instTy := mkApp (.const ``Inhabited [u]) xTy + let instVal := mkApp2 (.const ``Inhabited.mk [u]) xTy x + withLetDecl `inst instTy instVal fun inst => + k { inhabitants := cache.inhabitants.insert xTy x + localInstances := cache.localInstances.push inst + unresolved := if xTy.isForall then cache.unresolved.push x else cache.unresolved } + +private def InhabitantCache.withInhabitants (cache : InhabitantCache useOfNonempty) (xs : Array Expr) + (k : InhabitantCache useOfNonempty → MetaM α) : MetaM α := do + let rec go (i : Nat) cache : MetaM α := do if h : i < xs.size then - let x := xs[i] - let xTy ← inferType x - let u ← getLevel xTy - let instTy := mkApp (.const ``Inhabited [u]) xTy - let instVal := mkApp2 (.const ``Inhabited.mk [u]) xTy x - withLetDecl `inst instTy instVal fun inst => - go (i + 1) (insts.push inst) + cache.withInhabitant xs[i] (go (i + 1)) else - k insts - go 0 #[] + k cache + go 0 cache + +/-- Constructs an inhabitant if possible. +Eliminates the `localInstances` let variables introduced by `withInhabitant`. -/ +private def InhabitantCache.mkInhabitant? (cache : InhabitantCache useOfNonempty) (type : Expr) : + MetaM (Option Expr) := do + if let some x := cache.inhabitants[type]? then + return some x + else + try + let x ← if useOfNonempty then mkOfNonempty type else mkDefault type + mkLetFVars (usedLetOnly := true) cache.localInstances x + catch _ => + return none -private def mkInhabitant? (type : Expr) (useOfNonempty : Bool) : MetaM (Option Expr) := do - try - if useOfNonempty then - return some (← mkOfNonempty type) +/-- Tries to find new inhabitants from the worklist. Does a single pass, instantiating foralls as far as possible. -/ +private def InhabitantCache.resolve (cache : InhabitantCache useOfNonempty) + (kOk : InhabitantCache useOfNonempty → MetaM α) (kFail : MetaM α) : MetaM α := do + let unresolved := cache.unresolved + let rec go (i : Nat) (cache : InhabitantCache useOfNonempty) (doOk : Bool) : MetaM α := do + if h : i < unresolved.size then + let mut x := unresolved[i] + let mut xTy ← whnfCore (← inferType x) + let mut changed := false + while xTy.isForall do + if let some domVal ← cache.mkInhabitant? xTy.bindingDomain! then + x := .app x domVal + xTy ← whnfCore (xTy.bindingBody!.instantiate1 domVal) + changed := true + else + break + if changed then + cache.withInhabitant x fun cache => do go (i + 1) cache true + else + go (i + 1) { cache with unresolved := cache.unresolved.push x } doOk else - return some (← mkDefault type) - catch _ => - return none + if doOk then + kOk cache + else + kFail + go 0 { cache with unresolved := #[] } false /-- Find an inhabitant while doing delta unfolding. -/ -private partial def mkInhabitantForAux? (xs insts : Array Expr) (type : Expr) (useOfNonempty : Bool) : MetaM (Option Expr) := withIncRecDepth do - if let some val ← mkInhabitant? type useOfNonempty then - mkLambdaFVars xs (← mkLetFVars (usedLetOnly := true) insts val) +private partial def mkInhabitantForAux? (useOfNonempty : Bool) (cache : InhabitantCache useOfNonempty) (type : Expr) : + MetaM (Option Expr) := withIncRecDepth do + let type ← whnfCore type + if type.isForall then + forallTelescope type fun xs type' => + cache.withInhabitants xs fun cache => do + let some val ← mkInhabitantForAux? _ cache type' | return none + mkLambdaFVars xs val + else if let some val ← cache.mkInhabitant? type then + return val else - let type ← whnfCore type - if type.isForall then - forallTelescope type fun xs' type' => - withInhabitedInstances xs' fun insts' => - mkInhabitantForAux? (xs ++ xs') (insts ++ insts') type' useOfNonempty - else if let some type' ← unfoldDefinition? type then - mkInhabitantForAux? xs insts type' useOfNonempty - else - return none + cache.resolve + (kOk := fun cache => mkInhabitantForAux? _ cache type) + (kFail := do + if let some type ← unfoldDefinition? type then + mkInhabitantForAux? _ cache type + else + return none) /- TODO: add a global IO.Ref to let users customize/extend this procedure -/ -def mkInhabitantFor (declName : Name) (xs : Array Expr) (type : Expr) : MetaM Expr := - withInhabitedInstances xs fun insts => do - if let some val ← mkInhabitantForAux? xs insts type false <||> mkInhabitantForAux? xs insts type true then - return val - else - throwError "\ - failed to compile 'partial' definition '{declName}', could not prove that the type\ - {indentExpr (← mkForallFVars xs type)}\n\ - is nonempty.\n\ - \n\ - This process uses multiple strategies:\n\ - - It looks for a parameter that matches the return type.\n\ - - It tries synthesizing '{.ofConstName ``Inhabited}' and '{.ofConstName ``Nonempty}' \ - instances for the return type, while making every parameter into a local '{.ofConstName ``Inhabited}' instance.\n\ - - It tries unfolding the return type.\n\ - \n\ - If the return type is defined using the 'structure' or 'inductive' command, \ - you can try adding a 'deriving Nonempty' clause to it." +def mkInhabitantFor (declName : Name) (type : Expr) : MetaM Expr := do + if let some val ← mkInhabitantForAux? false {} type <||> mkInhabitantForAux? true {} type then + return val + else + throwError "\ + failed to compile 'partial' definition '{declName}', could not prove that the type\ + {indentExpr type}\n\ + is nonempty.\n\ + \n\ + This process uses multiple strategies:\n\ + - It looks for a parameter that matches the return type.\n\ + - It tries synthesizing '{.ofConstName ``Inhabited}' and '{.ofConstName ``Nonempty}' \ + instances for the return type, while making every parameter into a local '{.ofConstName ``Inhabited}' instance.\n\ + - It tries unfolding the return type.\n\ + \n\ + If the return type is defined using the 'structure' or 'inductive' command, \ + you can try adding a 'deriving Nonempty' clause to it." end Lean.Elab diff --git a/tests/lean/run/partialInhabitation.lean b/tests/lean/run/partialInhabitation.lean new file mode 100644 index 000000000000..918b05c30388 --- /dev/null +++ b/tests/lean/run/partialInhabitation.lean @@ -0,0 +1,17 @@ +/-! +# Tests of the metaprogram to prove types are inhabited for `partial` +-/ + +/-! +Example: callbacks that depend on some inhabited arguments. +-/ +partial def callback1 (k : Array Nat → α) : α := callback1 k + +partial def callback2 (k : Array Nat → IO α) : IO α := callback2 k + +partial def callback3 (k : Array Nat → α) : IO α := callback3 k + +/-! +Example: some arguments might depend on results of other functions. +-/ +partial def callback4 (k : Array α → β → α) (f : Unit → β) : IO α := callback4 k f