Skip to content

Commit

Permalink
clean up some fixpoint code
Browse files Browse the repository at this point in the history
  • Loading branch information
TimWhiting committed Dec 6, 2024
1 parent adb4a91 commit 851e687
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 104 deletions.
34 changes: 33 additions & 1 deletion std/data/rb-map.kk
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,36 @@ pub fun kvalue/map(s: rbmap<k,v>, f: (k, v) -> e x): e rbmap<k,x>

pub fun value/map(s: rbmap<k,v>, f: (v) -> e x): e rbmap<k,x>
val Rbmap(tree) = s
Rbmap(tree.map(fn(v) f(v)))
Rbmap(tree.map(fn(v) f(v)))

pub fun (==)(s1: rbmap<k,v>, s2: rbmap<k,v>, ^?k/order2: (k,k) -> e order2<k>, ^?v/order2: (v,v) -> e order2<v>): e bool
match order2(s1, s2)
Eq2 -> True
_ -> False

pub fun order2(orig1: rbmap<k,v>, orig2: rbmap<k,v>, ^?k/order2: (k,k) -> e order2<k>, ^?v/order2: (v,v) -> e order2<v>): e order2<rbmap<k,v>>
match orig1
Rbmap(t1) ->
match orig2
Rbmap(t2) -> order2(t1, t2, orig1, orig2)

fun rec/order2(t1: rbtree<k,v>, t2: rbtree<k,v>, orig1: rbmap<k,v>, orig2: rbmap<k,v>, ^?k/order2: (k,k) -> e order2<k>, ^?v/order2: (v,v) -> e order2<v>): e order2<rbmap<k,v>>
match t1
Leaf -> match t2
Leaf -> Eq2(orig1)
Node -> Lt2(orig1, orig2)
Node(_,l1,k1,v1,r1) -> match t2
Leaf -> Gt2(orig2, orig1)
Node(_,l2,k2,v2,r2) ->
match order2(k1, k2)
Lt2 -> Lt2(orig1, orig2)
Gt2 -> Gt2(orig2, orig1)
Eq2 ->
match order2(v1, v2)
Lt2 -> Lt2(orig1, orig2)
Gt2 -> Gt2(orig2, orig1)
Eq2 ->
match order2(l1, l2, orig1, orig2)
Lt2 -> Lt2(orig1, orig2)
Gt2 -> Gt2(orig2, orig1)
Eq2 -> order2(r1, r2, orig1, orig2)
19 changes: 18 additions & 1 deletion std/data/rb-set.kk
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,21 @@ pub fun list(s: rbset<k>) : list<k>

pub fun show(s: rbset<k>, ?k/show: k -> e string): e string
val Rbset(tree) = s
tree.keys.show
tree.keys.show

pub fun order2(t1: rbtree<k,()>, t2: rbtree<k,()>, orig1: rbset<k>, orig2: rbset<k>, ^?order2: (k,k) -> e order2<k>): e order2<rbset<k>>
match t1
Leaf -> match t2
Leaf -> Eq2(orig1)
Node -> Lt2(orig1, orig2)
Node(_,l1,k1,_,r1) -> match t2
Leaf -> Gt2(orig2, orig1)
Node(_,l2,k2,_,r2) ->
match order2(k1, k2)
Lt2 -> Lt2(orig1, orig2)
Gt2 -> Gt2(orig2, orig1)
Eq2 ->
match order2(l1, l2, orig1, orig2)
Lt2 -> Lt2(orig1, orig2)
Gt2 -> Gt2(orig2, orig1)
Eq2 -> order2(r1, r2, orig1, orig2)
16 changes: 7 additions & 9 deletions std/data/rbtree-bu.kk
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fip fun balance( z : zipper<k,v>, t : root<k,v> ) : rbtree<k,v>
Done -> Node(Black, t.to-node, k1, v1, r1)
z -> rebuild(z, t.to-node)

pub fip(1) fun zip/set(t : rbtree<k,v>, key : k, v : v, z : zipper<k,v>, ^?order2: (k,k) -> e order2<k>) : e rbtree<k,v>
pub fbip(1) fun zip/set(t : rbtree<k,v>, key : k, v : v, z : zipper<k,v>, ^?order2: (k,k) -> e order2<k>) : e rbtree<k,v>
match t
Node(c, l, kx, vx, r) ->
match order2(key, kx)
Expand All @@ -35,13 +35,13 @@ pub fip(1) fun zip/set(t : rbtree<k,v>, key : k, v : v, z : zipper<k,v>, ^?order
Eq2(ki) -> rebuild(z, Node(c, l, ki, v, r)) // Actually override the value, no balancing needed
Leaf -> balance(z, Root(Red, Leaf, key, v, Leaf)) // Insert a new node and balance

pub fip(1) fun zip/add(t : rbtree<k,v>, key : k, v : v, z : zipper<k,v>, ^?order2: (k,k) -> e order2<k>) : e rbtree<k,v>
pub fbip(1) fun zip/add(t : rbtree<k,v>, key : k, v : v, z : zipper<k,v>, ^?order2: (k,k) -> e order2<k>) : e rbtree<k,v>
match t
Node(c, l, kx, vx, r) ->
match order2(key, kx)
Lt2(ki, kj) -> add(l, ki, v, ZNodeL(c, z, kj, vx, r))
Gt2(kj, ki) -> add(r, ki, v, ZNodeR(c, l, kj, vx, z))
Eq2 -> rebuild(z, Node(c,l,kx,vx,r)) // No overriding
Eq2(kx') -> rebuild(z, Node(c,l,kx',vx,r)) // No overriding
Leaf -> balance(z, Root(Red, Leaf, key, v, Leaf)) // Insert a new node and balance

// Take a function that is called with Just the old value if it exists, or Nothing if it doesn't
Expand All @@ -51,16 +51,14 @@ pub fip(1) fun zip/insert(t : rbtree<k,v>, key : k, z : zipper<k,v>, ^f: (maybe<
match order2(key, kx)
Lt2(ki, kj) -> insert(l, ki, ZNodeL(c, z, kj, vx, r), f)
Gt2(kj, ki) -> insert(r, ki, ZNodeR(c, l, kj, vx, z), f)
Eq2 -> rebuild(z, Node(c,l,kx,f(Just(vx)),r)) // No overriding
Eq2(kx') -> rebuild(z, Node(c,l,kx',f(Just(vx)),r)) // No overriding
Leaf -> balance(z, Root(Red, Leaf, key, f(Nothing), Leaf)) // Insert a new node and balance

pub fip(1) fun bu/set(t: rbtree<k,v>, key: k, value: v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
pub fbip(1) fun bu/set(t: rbtree<k,v>, key: k, value: v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
t.set(key, value, Done)

pub fip(1) fun bu/add(t: rbtree<k,v>, key: k, value: v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
pub fbip(1) fun bu/add(t: rbtree<k,v>, key: k, value: v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
t.add(key, value, Done)

pub fip(1) fun bu/insert(t: rbtree<k,v>, key: k, ^f: (maybe<v>) -> e v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
pub fbip(1) fun bu/insert(t: rbtree<k,v>, key: k, f: (maybe<v>) -> e v, ^?order2: (k,k) -> e order2<k>): e rbtree<k,v>
t.insert(key, Done, f)


6 changes: 3 additions & 3 deletions std/data/rbtree.kk
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ pub reference type dhole // define as a reference type so the derivative tree
pub fun list(t: rbtree<k,v>) : list<(k,v)>
match t
Leaf -> []
Node(_,l,k,v,r) -> Cons((k,v), l.list ++ r.list)
Node(_,l,k,v,r) -> l.list ++ Cons((k,v), r.list)

pub fun keys(t: rbtree<k,v>) : list<k>
match t
Leaf -> []
Node(_,l,k,_,r) -> Cons(k, l.keys ++ r.keys)
Node(_,l,k,_,r) -> l.keys ++ Cons(k, r.keys)

pub fun values(t: rbtree<k,v>) : list<v>
match t
Leaf -> []
Node(_,l,_,v,r) -> Cons(v, l.values ++ r.values)
Node(_,l,_,v,r) -> l.values ++ Cons(v, r.values)

pub inline fun empty() : rbtree<k,v>
Leaf
Expand Down
87 changes: 43 additions & 44 deletions std/fixpoint/fixpoint-memo.kk
Original file line number Diff line number Diff line change
@@ -1,70 +1,69 @@
pub import std/data/rb-map
import std/fixpoint/lattice

effect cache<s,c>
fun add-result(s: s, r: c): ()
fun add-state(s: s): ()
fun is-cached(s: s): bool
// k = key, r = result
effect cache<k,r>
fun add-result(key: k, result: r): ()
fun is-cached(key: k): bool
ctl do-each(ss: list<a>): a
final ctl none(): b
ctl depend(s: s): c
ctl depend(key: k): r

fun cache(f: () -> <pure,cache<s,c>|e> b, ?order2: (s, s) -> pure order2<s>, .?change-lattice:change-lattice<r,c>, ?s/show: s -> string, ?c/show: c -> string): <pure|e> rbmap<s,r>
var m : some<s,r> rbmap<s,r> := empty()
var deps : some<s,c,e> rbmap<s, list<(c -> <pure|e> ())>> := empty()
fun update(s, c)
match deps.lookup(s)
// A fixpoint cache handler
fun cache(comp: () -> <pure,cache<k,r>|e> d, ?order2: (k, k) -> pure order2<k>,
.?change-lattice:change-lattice<b,r>, ?k/show: k -> string, ?r/show: r -> string): <pure|e> rbmap<k,b>
var cache : some<k,r> rbmap<k,r> := empty()
var deps : some<k,c,e> rbmap<k, list<(c -> <pure|e> ())>> := empty()
fun update(key, change)
match deps.lookup(key)
Just(resumes) ->
// trace("Updating " ++ resumes.length.show ++ " deps for " ++ s.show ++ " with " ++ c.c/show)
resumes.list/foreach(fn(res) {res(c); ()})
resumes.list/foreach(fn(resumption) {resumption(change); ()})
Nothing -> ()
val do =
with handler
fun add-state(s)
m := m.set(s, bottom)
fun add-result(s, c)
fun add-result(key, change)
// trace("Adding result for " ++ s.s/show ++ " " ++ c.c/show)
match m.lookup(s)
match cache.lookup(key)
Just(r') ->
val (changed, r'') = join(r', c)
val (changed, r'') = join(r', change)
if changed then
m := m.set(s, r'')
update(s, c)
cache := cache.set(key, r'')
update(key, change)
else ()
Nothing ->
m := m.set(s, bottom.join(c).snd)
update(s, c)
fun is-cached(s)
m.contains(s)
ctl depend(s)
cache := cache.set(key, bottom.join(change).snd)
update(key, change)
fun is-cached(key)
if cache.contains(key) || deps.contains(key) then True
else
cache := cache.set(key, bottom)
deps := deps.set(key, Nil)
False
ctl depend(key)
// trace("Adding dep for " ++ s.s/show)
match deps.lookup(s)
match deps.lookup(key)
Just(resumes) ->
val ress = Cons(fn(r) resume(r), resumes)
deps := deps.set(s, ress)
Nothing ->
deps := deps.set(s, [fn(r) resume(r)])
match m.lookup(s)
Just(r) ->
r.changes.foreach(fn(c) update(s, c))
Nothing -> ()
deps := deps.set(key, ress)
match cache.lookup(key)
Just(r) -> r.changes.foreach(fn(change) resume(change))
ctl do-each(ss)
ss.foreach(fn(s) resume(s))
final ctl none()
()
return(x) ()
f()
m
comp()
cache

fun memo(s, f)
match is-cached(s)
True -> depend(s)
False ->
add-state(s)
f(fn(ls) fix/each(s, ls))
// Inserts a memoization point at a recursive invocation
fun memo(key : k, func : ((list<() -> <cache<k,r>|e> r>) -> <cache<k,r>|e> r) -> <cache<k,r>|e> r): <cache<k,r>|e> r
match is-cached(key)
True -> depend(key)
False -> func(fn(ls) fix/each(key, ls))

fun fix/each(s: s, ls: list<(() -> <cache<s,c>|e> c)>): <cache<s,c>|e> c
val f = do-each(ls)
val r = f()
add-result(s,r)
r
fun fix/each(key: k, ls: list<(() -> <cache<k,r>|e> r)>): <cache<k,r>|e> r
val func = do-each(ls)
val result = func()
add-result(key, result)
result
64 changes: 63 additions & 1 deletion std/fixpoint/lattice.kk
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module std/fixpoint/lattice

import std/data/rb-set

pub struct change-lattice<r,c>
Expand All @@ -17,4 +18,65 @@ pub fun set/join(old: rbset<a>, new: a, ?order2: (a, a) -> pure order2<a>): pure
(change, l')

pub fun set/change-lattice(?order2: (a, a) -> pure order2<a>): change-lattice<rbset<a>, a>
Change-lattice(rb-set/empty(), fn(r, c) set/join(r, c), fn(x) x.list)
Change-lattice(rb-set/empty(), fn(r, c) set/join(r, c), fn(x) x.list)

pub value type simple-lattice<a>
LValue(a: a)
LTop
LBot

pub fun order2(a: simple-lattice<a>, b: simple-lattice<a>, ?order2: (a, a) -> order2<a>): pure order2<simple-lattice<a>>
match (a, b)
(LValue(a'), LValue(b')) ->
match order2(a', b')
Lt2 -> Lt2(a, b)
Gt2 -> Gt2(b, a)
Eq2 -> Eq2(a)
(LTop, LTop) -> Eq2(LTop)
(LTop, _) -> Gt2(LTop, b)
(_, LTop) -> Lt2(a, LTop)
(LBot, LBot) -> Eq2(LBot)
(LBot, _) -> Lt2(LBot, b)
(_, LBot) -> Gt2(a, LBot)

pub fun show(s: simple-lattice<a>, ?show: a -> string): string
match s
LValue(a) -> "LValue(" ++ show(a) ++ ")"
LTop -> "LTop"
LBot -> "LBot"

pub fun (==)(a: simple-lattice<a>, b: simple-lattice<a>, ?(==): (a, a) -> pure bool): pure bool
match (a, b)
(LValue(a'), LValue(b')) -> a' == b'
(LTop, LTop) -> True
(LBot, LBot) -> True
_ -> False

pub fun simple/join(a: simple-lattice<a>, b: simple-lattice<a>, ?(==): (a, a) -> pure bool): pure (bool, simple-lattice<a>)
match (a, b)
(LTop, _) -> (False, LTop)
(_, LTop) -> (True, LTop)
(LBot, x) -> (True, x)
(x, LBot) -> (False, x)
(LValue(a'), LValue(b')) -> if a' == b' then (False, a) else (True, LTop)

pub fun simple/change-lattice((==): (a, a) -> pure bool): pure change-lattice<simple-lattice<a>, simple-lattice<a>>
Change-lattice(LBot, fn(r, c) simple/join(r, c), fn(x) [x])

pub value type product-change<a,b>
CLeft(a: a)
CRight(b: b)

pub fun product/change-lattice(l: change-lattice<r1,c1>, r: change-lattice<r2,c2>): change-lattice<(r1,r2), product-change<c1,c2>>
Change-lattice((l.bottom, r.bottom),
(fn((l1,r1), c)
match c
CLeft(c1) ->
val (change, l2) = (l.join)(l1, c1)
return (change, (l2, r1))
CRight(c2) ->
val (change, r2) = (r.join)(r1, c2)
return (change, (l1, r2))
),
fn((l1,r1)) (l.changes)(l1).map(fn(x) CLeft(x)) ++ (r.changes)(r1).map(fn(x) CRight(x))
)
Loading

0 comments on commit 851e687

Please sign in to comment.