Skip to content

Commit

Permalink
Tail call optimization tricks in Norvig's lispy.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
rread committed Sep 10, 2015
1 parent fb0fc0e commit 4ca1d4b
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 139 deletions.
23 changes: 23 additions & 0 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,29 @@ func NewEnv(outer *Env) *Env {

}

func ExtendEnv(names []Symbol, values Data, outer *Env) (*Env, error) {
env := NewEnv(outer)
if len(names) != listLen(values) {
return nil, fmt.Errorf("parameter mismatch %v != %v", names, values)
}
for _, name := range names {

val := car(values)

if err := getError(val); err != nil {
return nil, err
}
env.Bind(name, val)

values = cdr(values)
if err := getError(val); err != nil {
return nil, err
}

}
return env, nil
}

func (e *Env) BindName(name string, i Data) {
sym := internSymbol(name)
e.vars[sym] = i
Expand Down
257 changes: 119 additions & 138 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,121 +178,132 @@ var (

func eval(expr Data, env *Env) (Data, error) {
log.Printf("eval: %T: %v\n", expr, expr)
switch e := expr.(type) {
case Boolean:
return e, nil
case Symbol:
return env.FindVar(e)
case Number:
return e, nil
case String:
return e, nil
case Null:
return e, nil
case *Pair:
c, _ := getSymbol(car(e))
/* non-Symbols fall through to default */
switch c {
case _quote:
return cadr(e), nil
case _define:
expr, err := getPair(cdr(e))
if err != nil {
return nil, err
}
err = definition(expr, env)
if err != nil {
return nil, err
}
// Return value of define is undefined
return _ok, nil
case _set:
d, err := getSymbol(cadr(e))
if err != nil {
return nil, err
}
val, err := eval(caddr(e), env)
if err != nil {
return nil, err
}
env.Find(d).Bind(d, val)
return nil, nil
case _if:
test, _ := eval(cadr(e), env)
if isTrue(test) {
return eval(caddr(e), env)
} else if listLen(e) > 3 {
return eval(cadddr(e), env)
}
return Empty, nil
case _let:
letExpr, err := let(cdr(e), env)
if err != nil {
return nil, err
}
return eval(letExpr, env)
case _begin:
return evalSequential(cdr(e), env)
case _quit:
os.Exit(0)
case _lambda:
params, err := getList(cadr(e))
if err != nil {
return nil, fmt.Errorf("bad params: %v", err)
}
body, err := getList(cddr(e))
if err != nil {
return nil, fmt.Errorf("bad body: %v", err)
}
return evalLambda(params, body, env)
case _vars:
for k, v := range env.vars {
log.Printf("%v: %v\n", k, v)
for {
switch e := expr.(type) {
case Boolean:
return e, nil
case Symbol:
return env.FindVar(e)
case Number:
return e, nil
case String:
return e, nil
case Null:
return e, nil
case *Pair:
c, _ := getSymbol(car(e))
/* non-Symbols fall through to default */
switch c {
case _quote:
return cadr(e), nil
case _define:
expr, err := getPair(cdr(e))
if err != nil {
return nil, err
}
err = definition(expr, env)
if err != nil {
return nil, err
}
// Return value of define is undefined
return _ok, nil
case _set:
d, err := getSymbol(cadr(e))
if err != nil {
return nil, err
}
val, err := eval(caddr(e), env)
if err != nil {
return nil, err
}
env.Find(d).Bind(d, val)
return nil, nil
case _if:
test, _ := eval(cadr(e), env)
if isTrue(test) {
expr = caddr(e)
} else if listLen(e) > 3 {
expr = cadddr(e)
} else {
return Empty, nil
}
case _let:
letExpr, err := let(cdr(e), env)
if err != nil {
return nil, err
}
return eval(letExpr, env)
case _begin:
e, err := getPair(cdr(e))
if err != nil {
return nil, err
}
for !nullp(e) {
if nullp(cdr(e)) {
expr = car(e)
break
}
_, err = eval(car(e), env)
if err != nil {
return nil, err
}
e, err = listNext(e)
if err != nil {
return nil, err
}

}

case _quit:
os.Exit(0)
case _lambda:
params, err := getList(cadr(e))
if err != nil {
return nil, fmt.Errorf("bad params: %v", err)
}
body, err := getList(cddr(e))
if err != nil {
return nil, fmt.Errorf("bad body: %v", err)
}
return evalLambda(params, body, env)
case _vars:
for k, v := range env.vars {
log.Printf("%v: %v\n", k, v)
}
return nil, nil
default:
log.Printf("procedure call %v", e)
proc, err := eval(car(e), env)
if err != nil {
return nil, err
}
args, err := evalArgs(cdr(e), env)
if err != nil {
return nil, err
}
switch f := proc.(type) {
case InternalFunc:
return f(args)
case *Lambda:
var err error
env, err = ExtendEnv(f.params, args, f.envt)
if err != nil {
return nil, err
}
expr = cons(_begin, f.body)
default:
return nil, fmt.Errorf("apply to a non function: %#v %v", proc, args)
}

}
case nil:
log.Fatal("parsed a nil?")
return nil, nil
default:
log.Printf("procedure call %v", e)
proc, err := eval(car(e), env)
if err != nil {
return nil, err
}
args, err := evalArgs(cdr(e), env)
if err != nil {
return nil, err
}
return apply(proc, args, env)
}
case nil:
log.Fatal("parsed a nil?")
return nil, nil
}
return nil, fmt.Errorf("Unparsable expression: %v", expr)
}

func evalSequential(d Data, env *Env) (Data, error) {
var v Data
e, err := getPair(d)
if err != nil {
return nil, err
}
for !nullp(e) {
log.Printf("begin: %v", e)
v, err = eval(car(e), env)
if err != nil {
return nil, err
}
if nullp(cdr(e)) {
break
}
e, err = listNext(e)
if err != nil {
return nil, err
}

}
return v, nil

}
func definition(defn *Pair, env *Env) error {
var value Data
var name Symbol
Expand Down Expand Up @@ -437,36 +448,6 @@ func let(expr Data, env *Env) (Data, error) {
return eval(result, env)
}

func apply(proc Data, args Data, env *Env) (Data, error) {
// log.Printf("apply: %v args: %v\n", proc, args)
switch f := proc.(type) {
case InternalFunc:
return f(args)
case *Lambda:
if len(f.params) != listLen(args) {
return nil, fmt.Errorf("parameter mismatch %v != %v", f.params, args)
}
for _, name := range f.params {
var err error
p, err := getPair(args)
if err != nil {
return nil, err
}
f.envt.Bind(name, car(p))
if nullp(cdr(p)) {
break
}
args = cdr(args)
if err, ok := args.(error); ok {
return nil, err
}
}
return evalSequential(f.body, f.envt)
default:
return nil, fmt.Errorf("apply to a non function: %#v %v", proc, args)
}
}

func replReader(in io.Reader, env *Env) (Data, error) {
// l := NewScanner(in)
buf := make([]byte, 1024)
Expand Down
23 changes: 22 additions & 1 deletion repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,27 @@ func S(v interface{}) string {
return fmt.Sprintf("%v", v)
}

var result Data

func BenchmarkFactorial(b *testing.B) {
b.StopTimer()
env := DefaultEnv()
value, err := repl("(define fact (lambda (n) (if (<= n 1) 1 (* n (fact (- n 1))))))", env)
if err != nil {
panic(err)
}
result = value
b.StartTimer()

for n := 0; n < b.N; n++ {
val, err := repl("(fact 100)", env)
if err != nil {
panic(err)
}
result = val
}
}

func TestRepl(t *testing.T) {

Convey("basic lexer testing", t, func() {
Expand Down Expand Up @@ -178,7 +199,7 @@ func TestRepl(t *testing.T) {
So(err, ShouldBeNil)
So(S(val), ShouldEqual, "(1 1)")

_, err = repl("(lambda () (+ 1 1))", env)
_, err = repl("(lambda () (+ 2 3) (+ 1 1))", env)
So(err, ShouldBeNil)

val, err = repl("((lambda () (+ 1 1)))", env)
Expand Down

0 comments on commit 4ca1d4b

Please sign in to comment.