Skip to content

Commit

Permalink
tests: add Lean formalization for AVL
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Lahfa <[email protected]>
  • Loading branch information
Ryan Lahfa committed Jul 1, 2024
1 parent ce6bcd5 commit 4a89dbd
Show file tree
Hide file tree
Showing 10 changed files with 984 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/lean/Avl.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import Avl.Extracted
-- import Avl.Insert ; TODO: the insert_loop does not exist anymore!
import Avl.Find
import Avl.Order
import Avl.Height
import Avl.Rotate
import Avl.Rebalance.Complete
23 changes: 23 additions & 0 deletions tests/lean/Avl/AVL.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Avl.Extracted
import Avl.Tree

open Tree (AVLTree)

namespace Tree

variable {T: Type}

open avl_verification

def AVLTree.balancingFactor (t: AVLTree T): ℤ := match t with
| .none => 0
| .some (AVLNode.mk _ left right _) => AVLTree.height left - AVLTree.height right

lemma AVLTree.balancingFactor_eq (t: AVLTree T): AVLTree.balancingFactor t = AVLTree.height (AVLTree.left t) - AVLTree.height (AVLTree.right t) := by sorry

@[simp]
lemma AVLTree.balancingFactor_some (left right: AVLTree T): AVLTree.balancingFactor (some (AVLNode.mk x left right h)) = AVLTree.height left - AVLTree.height right := by rfl

def AVLTree.isAVL (t: AVLTree T): Prop := |t.balancingFactor| <= 1

end Tree
127 changes: 127 additions & 0 deletions tests/lean/Avl/BinarySearchTree.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import Avl.Tree

namespace BST

open Primitives (Result Scalar)
open avl_verification (AVLNode Ordering)
open Tree (AVLTree AVLNode.left AVLNode.right AVLNode.val)

-- TODO: build a function to build a tree out of a left and right
-- with automatic height computation?

@[reducible]
def AVLNode.mk' (a: T) (left: AVLTree T) (right: AVLTree T): AVLNode T :=
let height := 1 + max left.height right.height
-- TODO: Scalar.ofNat would be nice...
-- discharge the proof by bounding left & right's height all the time... ?
-- Interesting remark:
-- Lean can support trees of arbitrary height
-- Rust cannot because the height computation will overflow at some point, Rust can only do with trees with representable (usize) height.
-- It's not big deal because the maximally sized tree is bigger than what modern computer can store at all (exabyte-sized tree).
AVLNode.mk a left right (@Scalar.ofInt _ height (by sorry))

inductive ForallNode (p: T -> Prop): AVLTree T -> Prop
| none : ForallNode p none
| some (a: T) (left: AVLTree T) (right: AVLTree T) : ForallNode p left -> p a -> ForallNode p right -> ForallNode p (some (AVLNode.mk a left right h))

theorem ForallNode.left {p: T -> Prop} {t: AVLTree T}: ForallNode p t -> ForallNode p t.left := by
intro Hpt
cases Hpt with
| none => simp [AVLTree.left, ForallNode.none]
| some a left right f_pleft f_pa f_pright => simp [AVLTree.left, f_pleft]

theorem ForallNode.right {p: T -> Prop} {t: AVLTree T}: ForallNode p t -> ForallNode p t.right := by
intro Hpt
cases Hpt with
| none => simp [AVLTree.right, ForallNode.none]
| some a left right f_pleft f_pa f_pright => simp [AVLTree.right, f_pright]

theorem ForallNode.label {a: T} {p: T -> Prop} {left right: AVLTree T}: ForallNode p (AVLNode.mk a left right h) -> p a := by
intro Hpt
cases Hpt with
| some a left right f_pleft f_pa f_pright => exact f_pa

theorem ForallNode.not_mem {a: T} (p: T -> Prop) (t: Option (AVLNode T)): ¬ p a -> ForallNode p t -> a ∉ AVLTree.set t := fun Hnpa Hpt => by
cases t with
| none => simp [AVLTree.set]; tauto
| some t =>
cases Hpt with
| some b left right f_pbleft f_pb f_pbright =>
simp [AVLTree.set_some]
push_neg
split_conjs
. by_contra hab; rw [hab] at Hnpa; exact Hnpa f_pb
. exact ForallNode.not_mem p left Hnpa f_pbleft
. exact ForallNode.not_mem p right Hnpa f_pbright

theorem ForallNode.not_mem' {a: T} (p: T -> Prop) (t: Option (AVLNode T)): p a -> ForallNode (fun x => ¬p x) t -> a ∉ AVLTree.set t := fun Hpa Hnpt => by
refine' ForallNode.not_mem (fun x => ¬ p x) t _ _
simp [Hpa]
exact Hnpt

theorem ForallNode.imp {p q: T -> Prop} {t: AVLTree T}: (∀ x, p x -> q x) -> ForallNode p t -> ForallNode q t := fun Himp Hpt => by
induction Hpt
. simp [ForallNode.none]
. constructor
. assumption
. apply Himp; assumption
. assumption

-- This is the binary search invariant.
variable [LinearOrder T]
inductive Invariant: AVLTree T -> Prop
| none : Invariant none
| some (a: T) (left: AVLTree T) (right: AVLTree T) :
ForallNode (fun v => v < a) left -> ForallNode (fun v => a < v) right
-> Invariant left -> Invariant right -> Invariant (some (AVLNode.mk a left right h))

@[simp]
theorem singleton_bst {a: T}: Invariant (some (AVLNode.mk a none none h)) := by
apply Invariant.some
all_goals simp [ForallNode.none, Invariant.none]

theorem left {t: AVLTree T}: Invariant t -> Invariant t.left := by
intro H
induction H with
| none => exact Invariant.none
| some _ _ _ _ _ _ _ _ _ => simp [AVLTree.left]; assumption

theorem right {t: AVLTree T}: Invariant t -> Invariant t.right := by
intro H
induction H with
| none => exact Invariant.none
| some _ _ _ _ _ _ _ _ _ => simp [AVLTree.right]; assumption

-- TODO: ask at most for LT + Irreflexive (lt_irrefl) + Trichotomy (le_of_not_lt)?
theorem left_pos {left right: Option (AVLNode T)} {a x: T}: BST.Invariant (some (AVLNode.mk a left right h)) -> x ∈ AVLTree.set (AVLNode.mk a left right h) -> x < a -> x ∈ AVLTree.set left := fun Hbst Hmem Hxa => by
simp [AVLTree.set_some] at Hmem
rcases Hmem with (Heq | Hleft) | Hright
. rewrite [Heq] at Hxa; exact absurd Hxa (lt_irrefl _)
. assumption
. exfalso

-- Hbst -> x ∈ right -> ForallNode (fun v => ¬ v < a)
refine' ForallNode.not_mem' (fun v => v < a) right Hxa _ _
simp [le_of_not_lt]
cases Hbst with
| some _ _ _ _ Hforall _ =>
refine' ForallNode.imp _ Hforall
exact fun x => le_of_lt
assumption

theorem right_pos {left right: Option (AVLNode T)} {a x: T}: BST.Invariant (some (AVLNode.mk a left right h)) -> x ∈ AVLTree.set (AVLNode.mk a left right h) -> a < x -> x ∈ AVLTree.set right := fun Hbst Hmem Hax => by
simp [AVLTree.set_some] at Hmem
rcases Hmem with (Heq | Hleft) | Hright
. rewrite [Heq] at Hax; exact absurd Hax (lt_irrefl _)
. exfalso
refine' ForallNode.not_mem' (fun v => a < v) left Hax _ _
simp [le_of_not_lt]
cases Hbst with
| some _ _ _ Hforall _ _ =>
refine' ForallNode.imp _ Hforall
exact fun x => le_of_lt
assumption
. assumption


end BST
48 changes: 48 additions & 0 deletions tests/lean/Avl/Find.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import Avl.Tree
import Avl.BinarySearchTree
import Avl.Specifications
import Avl.Extracted

namespace Implementation

open Primitives
open avl_verification
open Tree (AVLTree AVLTree.set)
open Specifications (OrdSpecLinearOrderEq infallible ltOfRustOrder gtOfRustOrder)

variable (T: Type) (H: avl_verification.Ord T) [DecidableEq T] [LinearOrder T] (Ospec: OrdSpecLinearOrderEq H)

@[pspec]
def AVLTreeSet.find_loop_spec
(a: T) (t: Option (AVLNode T)):
BST.Invariant t -> ∃ b, AVLTreeSet.find_loop _ H a t = Result.ok b ∧ (b ↔ a ∈ AVLTree.set t) := fun Hbst => by
rewrite [AVLTreeSet.find_loop]
match t with
| none => use false; simp [AVLTree.set]; tauto
| some (AVLNode.mk b left right _) =>
dsimp only
have : ∀ a b, ∃ o, H.cmp a b = .ok o := infallible H
progress keep Hordering as ⟨ ordering ⟩
cases ordering
all_goals dsimp only
. convert (AVLTreeSet.find_loop_spec a right (BST.right Hbst)) using 4
apply Iff.intro
-- We apply a localization theorem here.
. intro Hmem; exact (BST.right_pos Hbst Hmem (ltOfRustOrder _ _ _ Hordering))
. intro Hmem; simp [AVLTree.set_some]; right; assumption
. simp [Ospec.equivalence _ _ Hordering]
. convert (AVLTreeSet.find_loop_spec a left (BST.left Hbst)) using 4
apply Iff.intro
-- We apply a localization theorem here.
. intro Hmem; exact (BST.left_pos Hbst Hmem (gtOfRustOrder _ _ _ Hordering))
. intro Hmem; simp [AVLTree.set_some]; left; right; assumption


def AVLTreeSet.find_spec
(a: T) (t: AVLTreeSet T):
BST.Invariant t.root -> ∃ b, t.find _ H a = Result.ok b ∧ (b ↔ a ∈ AVLTree.set t.root) := fun Hbst => by
rw [AVLTreeSet.find]
progress; simp only [Result.ok.injEq, exists_eq_left']; assumption

end Implementation

160 changes: 160 additions & 0 deletions tests/lean/Avl/Height.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import Avl.Extracted
import Avl.Tree
import Avl.BinarySearchTree

namespace Implementation

variable {T: Type}

open avl_verification
open Primitives
open Tree (AVLTree AVLNode.left AVLNode.right AVLTree.height_node AVLNode.memoized_height AVLNode.height_left_lt_tree AVLNode.height_right_lt_tree)
open BST (AVLNode.mk')

variable (t: AVLNode T) [O: LinearOrder T] (Tcopy: core.marker.Copy T) (H: avl_verification.Ord T)


instance (ty: ScalarTy) : InBounds ty 0 where
hInBounds := by
induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide

theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by
apply (Scalar.le_equiv _ _).2
convert x.hmin
cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min]

@[simp]
theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)

@[simp]
theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)

@[ext]
theorem Scalar.ext {ty} (a b: Scalar ty): a.val = b.val -> a = b := (Scalar.eq_equiv a b).2

@[pspec]
def max_spec {a b: T}: ∃ o, avl_verification.max _ H Tcopy a b = .ok o ∧ o = O.max a b := by sorry

@[pspec]
def AVLNode.left_height_spec
(left: AVLNode T): (AVLNode.mk x (some left) right h).left_height = left.height
:= by simp only [AVLNode.left_height]

@[pspec]
def AVLNode.right_height_spec
(right: AVLNode T): (AVLNode.mk x left (some right) h).right_height = right.height
:= by simp only [AVLNode.right_height]

@[simp, norm_cast]
theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (↑a) (↑b): ℤ) := by
-- TODO: there should be a shorter way to prove this.
rw [max_def, max_def]
split_ifs <;> simp_all
refine' absurd _ (lt_irrefl a)
exact lt_of_le_of_lt (by assumption) ((Scalar.lt_equiv _ _).2 (by assumption))

-- TODO:
@[pspec]
def AVLNode.height_spec (t: AVLNode T): AVLTree.height_node t ≤ Scalar.max .Usize -> ∃ v, t.height = .ok v ∧ v.val = AVLTree.height_node t
:= by
haveI: Fact (¬ ScalarTy.isSigned .Usize) := ⟨by simp [ScalarTy.isSigned]⟩
intro Hbound
simp [AVLNode.height]
match t with
| AVLNode.mk x left right h =>
rcases Hleft: left with _ | ⟨ a, left_left, left_right, h_left ⟩ <;> rcases Hright: right with _ | ⟨ b, right_left, right_right, h_right ⟩ <;> simp only [AVLNode.left_height,
AVLNode.right_height, bind_tc_ok, max_self, Nat.cast_add,
Nat.cast_one]
-- (none, none) case.
. progress with max_spec as ⟨ w, Hw ⟩
simp only [Hw, max_self, AVLTree.height_node_of_mk, Nat.cast_add, Nat.cast_one]
use 1#usize; norm_cast
-- (none, some .) case.
. progress with height_spec as ⟨ w, Hw ⟩
. push_cast
refine' le_trans _ Hbound
apply le_of_lt; rw [Hright]
exact_mod_cast AVLNode.height_right_lt_tree _
. progress with max_spec as ⟨ M, Hm ⟩
rw [Hm]
have: 1 + w.val ≤ Scalar.max .Usize := by
rw [Hw]
refine' le_trans _ Hbound
conv =>
rhs
rw [Hright, AVLTree.height_node, AVLTree.height]
push_cast
refine' Int.add_le_add_left _ _
exact Int.le_max_right _ _
simp only [Scalar.max_unsigned_left_zero_eq, ge_iff_le, zero_le, max_eq_right, Nat.cast_add,
Nat.cast_one]
progress with Usize.add_spec as ⟨ X, Hx ⟩
simp only [Result.ok.injEq, Nat.cast_add,
Nat.cast_one, Nat.cast_max, exists_eq_left', Hx, Scalar.ofInt_val_eq, Hw, add_right_inj]
conv =>
rhs
rw [AVLTree.height_node, AVLTree.height, (max_eq_right (zero_le _)), AVLTree.height]
-- TODO: render invariant by commutativity.
-- (some ., none) case, above.
. sorry
-- (some ., some .) case.
. progress with height_spec as ⟨ c, Hc ⟩
-- TODO: factor me out...
push_cast
refine' le_trans _ Hbound
apply le_of_lt; rw [Hleft]
exact_mod_cast AVLNode.height_left_lt_tree _
progress with height_spec as ⟨ d, Hd ⟩
push_cast
refine' le_trans _ Hbound
apply le_of_lt; rw [Hright]
exact_mod_cast AVLNode.height_right_lt_tree _
progress with max_spec as ⟨ M, Hm ⟩
have: 1 + M.val ≤ Scalar.max .Usize := by
rw [Hm]
refine' le_trans _ Hbound
rw [Hleft, Hright, AVLTree.height_node, AVLTree.height, AVLTree.height]
push_cast
rw [Hc, Hd]
progress with Usize.add_spec as ⟨ X, Hx ⟩
simp [Hx, Hm, Hc, Hd, AVLTree.height]
decreasing_by
all_goals (simp_wf; try simp [Hleft]; try simp [Hright]; try linarith)

-- TODO: discharge all bound requirements
-- by taking (multiple?) hypotheses.
@[pspec]
def AVLNode.update_height_spec (x: T) (h: Usize) (left right: AVLTree T): ∃ t_new, AVLNode.update_height _ (AVLNode.mk x left right h) = .ok t_new ∧ t_new = AVLNode.mk' x left right := by
simp [AVLNode.update_height]
haveI: Fact (¬ ScalarTy.isSigned .Usize) := ⟨by simp [ScalarTy.isSigned]⟩
rcases Hleft: left with _ | ⟨ a, left_left, left_right, h_left ⟩ <;> rcases Hright: right with _ | ⟨ b, right_left, right_right, h_right ⟩ <;> simp [AVLNode.right_height, AVLNode.left_height]
-- TODO: clean up proof structure
-- it's always the same.
. progress with max_spec as ⟨ w, Hw ⟩
rw [Hw];
progress as ⟨ H, H_height ⟩
. simp; scalar_tac
. simp only [Result.ok.injEq, AVLNode.mk.injEq, true_and]; ext; simp [H_height, AVLTree.height]
. progress with height_spec as ⟨ c, Hc ⟩
. sorry
. progress with max_spec as ⟨ w, Hw ⟩
simp at Hw; rw [Hw]; progress as ⟨ H, H_height ⟩; simp; sorry -- 1 + c ≤ Usize.max
simp only [Result.ok.injEq, AVLNode.mk.injEq, true_and]; ext; simp [AVLTree.height, H_height, Hc]
. progress with height_spec as ⟨ c, Hc ⟩
. sorry
. progress with max_spec as ⟨ w, Hw ⟩
progress as ⟨ H, H_height ⟩
. sorry
. simp only [Result.ok.injEq, AVLNode.mk.injEq, true_and]; ext; simp [AVLTree.height, H_height, Hw, Hc]
. progress with height_spec as ⟨ c, Hc ⟩
. sorry
. progress with height_spec as ⟨ d, Hd ⟩
. sorry
. progress with max_spec as ⟨ w, Hw ⟩
progress as ⟨ H, H_height ⟩
. sorry
. simp only [Result.ok.injEq, AVLNode.mk.injEq, true_and]; ext; simp [AVLTree.height, H_height, Hw, Hc, Hd]

end Implementation
Loading

0 comments on commit 4a89dbd

Please sign in to comment.