From eff4bd7e90d0b2daf5ffcb763a8827bd78efc012 Mon Sep 17 00:00:00 2001 From: Roxy Light Date: Tue, 28 Jan 2025 19:59:33 -0800 Subject: [PATCH] Match C Lua's pcall semantics Stop propagating message handler through a protected call. Split off `*State.PCall` from `*State.Call`. --- internal/frontend/eval.go | 14 +-- internal/lua/auxlib.go | 4 +- internal/lua/auxlib_test.go | 2 +- internal/lua/baselib.go | 12 +- internal/lua/baselib_test.go | 8 +- internal/lua/example_test.go | 2 +- internal/lua/lua.go | 108 +++++++++++++--- internal/lua/lua_test.go | 221 +++++++++++++++++++++++++++++++-- internal/lua/stringlib.go | 4 +- internal/lua/stringlib_test.go | 2 +- internal/lua/tablelib.go | 2 +- internal/lua/tablelib_test.go | 2 +- internal/lua/vm.go | 10 +- internal/lua/vm_test.go | 38 +++--- 14 files changed, 355 insertions(+), 74 deletions(-) diff --git a/internal/frontend/eval.go b/internal/frontend/eval.go index aadf860..85aa05c 100644 --- a/internal/frontend/eval.go +++ b/internal/frontend/eval.go @@ -132,7 +132,7 @@ func NewEval(storeDir zbstore.Directory, store *jsonrpc.Client, cacheDB string) if err := eval.l.Load(strings.NewReader(preludeSource), lua.AbstractSource("(prelude)"), "t"); err != nil { return nil, err } - if err := eval.l.Call(ctx, 0, 0, 0); err != nil { + if err := eval.l.Call(ctx, 0, 0); err != nil { return nil, err } @@ -206,7 +206,7 @@ func (eval *Eval) File(ctx context.Context, exprFile string, attrPaths []string) if err := loadFile(&eval.l, exprFile); err != nil { return nil, err } - if err := eval.l.Call(ctx, 0, 1, -2); err != nil { + if err := eval.l.PCall(ctx, 0, 1, -2); err != nil { return nil, err } return eval.attrPaths(ctx, attrPaths) @@ -232,7 +232,7 @@ func (eval *Eval) Expression(ctx context.Context, expr string, attrPaths []strin if err := loadExpression(&eval.l, expr); err != nil { return nil, err } - if err := eval.l.Call(ctx, 0, 1, -2); err != nil { + if err := eval.l.PCall(ctx, 0, 1, -2); err != nil { return nil, err } return eval.attrPaths(ctx, attrPaths) @@ -261,7 +261,7 @@ func (eval *Eval) attrPaths(ctx context.Context, paths []string) ([]any, error) return result, fmt.Errorf("%s: %v", p, err) } eval.l.PushValue(-2) - if err := eval.l.Call(ctx, 1, 1, 0); err != nil { + if err := eval.l.Call(ctx, 1, 1); err != nil { return result, fmt.Errorf("%s: %v", p, err) } x, err := luaToGo(ctx, &eval.l) @@ -411,7 +411,7 @@ func loadFunction(ctx context.Context, l *lua.State) (int, error) { } l.PushValue(lua.UpvalueIndex(1)) l.Insert(1) - if err := l.Call(ctx, maxLoadArgs, lua.MultipleReturns, 0); err != nil { + if err := l.Call(ctx, maxLoadArgs, lua.MultipleReturns); err != nil { return 0, err } return l.Top(), nil @@ -532,7 +532,7 @@ func dofileFunction(ctx context.Context, l *lua.State) (int, error) { // loadfile(filename) l.PushValue(lua.UpvalueIndex(1)) l.Insert(1) - if err := l.Call(ctx, 1, 2, 0); err != nil { + if err := l.Call(ctx, 1, 2); err != nil { return 0, fmt.Errorf("dofile: %v", err) } if l.IsNil(-2) { @@ -542,7 +542,7 @@ func dofileFunction(ctx context.Context, l *lua.State) (int, error) { l.Pop(1) // Call the loaded function. - if err := l.Call(ctx, 0, lua.MultipleReturns, 0); err != nil { + if err := l.Call(ctx, 0, lua.MultipleReturns); err != nil { return 0, fmt.Errorf("dofile %s: %v", resolved, err) } return l.Top(), nil diff --git a/internal/lua/auxlib.go b/internal/lua/auxlib.go index b1d321e..27d710e 100644 --- a/internal/lua/auxlib.go +++ b/internal/lua/auxlib.go @@ -53,7 +53,7 @@ func CallMeta(ctx context.Context, l *State, obj int, event string) (bool, error return false, nil } l.PushValue(obj) - if err := l.Call(ctx, 1, 1, 0); err != nil { + if err := l.Call(ctx, 1, 1); err != nil { return true, fmt.Errorf("lua: call metafield %q: %w", event, err) } return true, nil @@ -371,7 +371,7 @@ func Require(ctx context.Context, l *State, modName string, global bool, openf F l.Pop(1) // remove field l.PushClosure(0, openf) l.PushString(modName) - if err := l.Call(ctx, 1, 1, 0); err != nil { + if err := l.Call(ctx, 1, 1); err != nil { return fmt.Errorf("lua: require %q: %w", modName, err) } l.PushValue(-1) diff --git a/internal/lua/auxlib_test.go b/internal/lua/auxlib_test.go index 47d0bab..1c5a1da 100644 --- a/internal/lua/auxlib_test.go +++ b/internal/lua/auxlib_test.go @@ -73,7 +73,7 @@ func TestWhere(t *testing.T) { if err := state.Load(strings.NewReader(test.luaCode), chunkName, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } diff --git a/internal/lua/baselib.go b/internal/lua/baselib.go index 6a9ce05..b36dc70 100644 --- a/internal/lua/baselib.go +++ b/internal/lua/baselib.go @@ -118,7 +118,7 @@ func newBaseDofile(loadfile Function) Function { // loadfile(filename) l.PushClosure(0, loadfile) l.Insert(1) - if err := l.Call(ctx, 1, 2, 0); err != nil { + if err := l.Call(ctx, 1, 2); err != nil { return 0, err } if l.IsNil(-2) { @@ -128,7 +128,7 @@ func newBaseDofile(loadfile Function) Function { l.Pop(1) // Call the loaded function. - if err := l.Call(ctx, 0, MultipleReturns, 0); err != nil { + if err := l.Call(ctx, 0, MultipleReturns); err != nil { return 0, err } return l.Top(), nil @@ -358,7 +358,7 @@ func basePairs(ctx context.Context, l *State) (int, error) { } if Metafield(l, 1, "__pairs") != TypeNil { l.PushValue(1) // self for metamethod - if err := l.Call(ctx, 1, 3, 0); err != nil { + if err := l.Call(ctx, 1, 3); err != nil { return 0, err } return 3, nil @@ -378,7 +378,7 @@ func basePCall(ctx context.Context, l *State) (int, error) { l.PushBoolean(true) l.Insert(1) - if err := l.Call(ctx, l.Top()-2, MultipleReturns, 0); err != nil { + if err := l.PCall(ctx, l.Top()-2, MultipleReturns, 0); err != nil { l.PushBoolean(false) // TODO(someday): Push error object from err. l.PushString(err.Error()) @@ -408,7 +408,7 @@ func baseXPCall(ctx context.Context, l *State) (int, error) { l.PushBoolean(true) l.Rotate(3, 1) - if err := l.Call(ctx, numArgs, MultipleReturns, 1); err != nil { + if err := l.PCall(ctx, numArgs, MultipleReturns, 1); err != nil { l.PushBoolean(false) // TODO(someday): Push error object from err. l.PushString(err.Error()) @@ -647,7 +647,7 @@ func (r *luaFunctionReader) ReadByte() (byte, error) { } r.s, r.i = "", 0 // Prevent unreading. r.l.PushValue(1) - r.err = r.l.Call(r.ctx, 0, 1, 0) + r.err = r.l.Call(r.ctx, 0, 1) if r.err != nil { return 0, r.err } diff --git a/internal/lua/baselib_test.go b/internal/lua/baselib_test.go index a8d746a..b68f8cd 100644 --- a/internal/lua/baselib_test.go +++ b/internal/lua/baselib_test.go @@ -27,7 +27,7 @@ func TestAssert(t *testing.T) { } state.PushBoolean(true) - if err := state.Call(ctx, 1, MultipleReturns, 0); err != nil { + if err := state.Call(ctx, 1, MultipleReturns); err != nil { t.Fatal(err) } @@ -57,7 +57,7 @@ func TestAssert(t *testing.T) { } state.PushBoolean(false) - if err := state.Call(ctx, 1, MultipleReturns, 0); err == nil { + if err := state.Call(ctx, 1, MultipleReturns); err == nil { t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error") } else if got, want := err.Error(), "assertion failed!"; got != want { t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want) @@ -82,7 +82,7 @@ func TestAssert(t *testing.T) { } state.PushNil() - if err := state.Call(ctx, 1, MultipleReturns, 0); err == nil { + if err := state.Call(ctx, 1, MultipleReturns); err == nil { t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error") } else if got, want := err.Error(), "assertion failed!"; got != want { t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want) @@ -109,7 +109,7 @@ func TestAssert(t *testing.T) { state.PushBoolean(false) const msg = "bork bork bork" state.PushString(msg) - if err := state.Call(ctx, 2, MultipleReturns, 0); err == nil { + if err := state.Call(ctx, 2, MultipleReturns); err == nil { t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error") } else if got, want := err.Error(), msg; got != want { t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want) diff --git a/internal/lua/example_test.go b/internal/lua/example_test.go index 3aad007..43292ed 100644 --- a/internal/lua/example_test.go +++ b/internal/lua/example_test.go @@ -29,7 +29,7 @@ func Example() { if err := state.Load(strings.NewReader(luaSource), luaSource, "t"); err != nil { log.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { log.Fatal(err) } // Output: diff --git a/internal/lua/lua.go b/internal/lua/lua.go index 8c5ff18..7a5a3f7 100644 --- a/internal/lua/lua.go +++ b/internal/lua/lua.go @@ -1431,7 +1431,7 @@ func (l *State) SetUserValue(idx int, n int) bool { return true } -// Call calls a function (or callable object) in protected mode. +// Call calls a function (or callable object). // // To do a call you must use the following protocol: // first, the function to be called is pushed onto the stack; @@ -1454,20 +1454,79 @@ func (l *State) SetUserValue(idx int, n int) bool { // // # Error Handling // +// If an error occurs during the function call, +// it is returned as a Go error value. +// (This is in contrast to the C Lua API, which longjmps to the last protected call.) +// If a caller used [*State.PCall] to set a message handler, +// then the message handler is called before unwinding the stack +// and before Call returns. +func (l *State) Call(ctx context.Context, nArgs, nResults int) error { + if nArgs < 0 { + return errors.New("negative argument count") + } + if nResults < 0 && nResults != MultipleReturns { + return errors.New("negative result count") + } + if l.Top() < nArgs+1 { + return errMissingArguments + } + l.init() + if nResults != MultipleReturns && cap(l.stack)-len(l.stack) < nResults-nArgs { + l.Pop(nArgs + 1) + return fmt.Errorf("results from function overflow current stack size") + } + + isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, callOptions{ + numResults: nResults, + }) + if err != nil { + return err + } + if isLua { + if err := l.exec(ctx); err != nil { + return err + } + } + return nil +} + +// PCall calls a function (or callable object) in protected mode. +// +// To do a call you must use the following protocol: +// first, the function to be called is pushed onto the stack; +// then, the arguments to the call are pushed in direct order; +// that is, the first argument is pushed first. +// Finally you call PCall; +// nArgs is the number of arguments that you pushed onto the stack. +// When the function returns, +// all arguments and the function value are popped +// and the call results are pushed onto the stack. +// The number of results is adjusted to nResults, +// unless nResults is [MultipleReturns]. +// In this case, all results from the function are pushed; +// Lua takes care that the returned values fit into the stack space, +// but it does not ensure any extra space in the stack. +// The function results are pushed onto the stack in direct order +// (the first result is pushed first), +// so that after the call the last result is on the top of the stack. +// PCall always removes the function and its arguments from the stack. +// +// # Error Handling +// // If the msgHandler argument is 0, // then if an error occurs during the function call, // it is returned as a Go error value. // (This is in contrast to the C Lua API which pushes an error object onto the stack.) // Otherwise, msgHandler is the stack index of a message handler. // In case of runtime errors, this handler will be called with the error object -// and Call will push its return value onto the stack. -// The return value's string value will be used as a Go error returned by Call. +// and PCall will push its return value onto the stack. +// The return value's string value will be used as a Go error returned by PCall. // // Typically, the message handler is used to add more debug information to the error object, // such as a stack traceback. // Such information cannot be gathered after the return of a [State] method, // since by then the stack will have been unwound. -func (l *State) Call(ctx context.Context, nArgs, nResults, msgHandler int) error { +func (l *State) PCall(ctx context.Context, nArgs, nResults, msgHandler int) error { if nArgs < 0 { return errors.New("negative argument count") } @@ -1497,7 +1556,11 @@ func (l *State) Call(ctx context.Context, nArgs, nResults, msgHandler int) error } } - isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, nResults, false, msgHandlerFunction) + isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, callOptions{ + numResults: nResults, + protected: true, + messageHandler: msgHandlerFunction, + }) if err != nil { if msgHandler != 0 { l.push(l.errorToValue(err)) @@ -1526,7 +1589,7 @@ func (l *State) call(ctx context.Context, numResults int, f value, args ...value functionIndex := len(l.stack) l.stack = append(l.stack, f) l.stack = append(l.stack, args...) - isLua, err := l.prepareCall(ctx, functionIndex, numResults, false, nil) + isLua, err := l.prepareCall(ctx, functionIndex, callOptions{numResults: numResults}) if err != nil { return err } @@ -1550,6 +1613,14 @@ func (l *State) call1(ctx context.Context, f value, args ...value) (value, error return v, nil } +// callOptions holds optional arguments to [*State.prepareCall]. +type callOptions struct { + numResults int + isTailCall bool + protected bool + messageHandler function +} + // prepareCall pushes a new [callFrame] onto l.callStack // to start executing a new function. // The caller must have pushed the function to call @@ -1579,11 +1650,14 @@ func (l *State) call1(ctx context.Context, f value, args ...value) (value, error // If prepareCall calls a Go function and it returns an error, // it will call the message handler (if any) // before popping the Go function's frame off the call stack. -func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, isTailCall bool, messageHandler function) (isLua bool, err error) { +func (l *State) prepareCall(ctx context.Context, functionIndex int, opts callOptions) (isLua bool, err error) { var nextMessageHandler *messageHandlerState - if messageHandler != nil { - nextMessageHandler = &messageHandlerState{function: messageHandler} - } else { + switch { + case opts.messageHandler != nil: + nextMessageHandler = &messageHandlerState{function: opts.messageHandler} + case opts.protected: + nextMessageHandler = nil + default: nextMessageHandler = l.frame().messageHandler } @@ -1596,8 +1670,8 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, } newFrame := callFrame{ functionIndex: functionIndex, - numResults: numResults, - isTailCall: isTailCall, + numResults: opts.numResults, + isTailCall: opts.isTailCall, messageHandler: nextMessageHandler, } if !l.grow(newFrame.registerStart() + int(f.proto.MaxStackSize)) { @@ -1613,7 +1687,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, newFrame.numExtraArguments = numExtraArguments } } - if isTailCall { + if opts.isTailCall { // Move function and arguments up to the frame pointer. frame := l.frame() fp := frame.framePointer() @@ -1639,7 +1713,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, l.callStack = append(l.callStack, callFrame{ functionIndex: functionIndex, - numResults: numResults, + numResults: opts.numResults, messageHandler: nextMessageHandler, }) n, err := f.cb(ctx, l) @@ -1659,7 +1733,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, // and pop call frames. newStackTop := functionIndex newCallStackTop := len(l.callStack) - 1 - if isTailCall { + if opts.isTailCall { newCallStackTop-- newStackTop = l.callStack[newStackTop].framePointer() } @@ -1668,7 +1742,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, return false, err } - if isTailCall { + if opts.isTailCall { // Pop the Go function stack frame // so that finishCall will move the results to the caller's caller's frame. l.popCallStack() @@ -1927,7 +2001,7 @@ func (l *State) concatMetamethod(ctx context.Context) error { rotate(l.stack[len(l.stack)-3:], 1) // Call metamethod. - isLua, err := l.prepareCall(ctx, len(l.stack)-3, 1, false, nil) + isLua, err := l.prepareCall(ctx, len(l.stack)-3, callOptions{numResults: 1}) if err != nil { return err } diff --git a/internal/lua/lua_test.go b/internal/lua/lua_test.go index ebe3d40..977c32e 100644 --- a/internal/lua/lua_test.go +++ b/internal/lua/lua_test.go @@ -81,7 +81,7 @@ func TestLoad(t *testing.T) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -115,7 +115,7 @@ func TestLoad(t *testing.T) { if err := state.Load(bytes.NewReader(chunk), "", "b"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -158,7 +158,7 @@ func TestLoad(t *testing.T) { if err := state.Load(bytes.NewReader(test.data), source, "bt"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -494,7 +494,7 @@ func TestCompare(t *testing.T) { state.PushValue(-3) state.PushValue(-3) - if err := state.Call(ctx, 2, 1, 0); err != nil { + if err := state.Call(ctx, 2, 1); err != nil { t.Logf("(%s %v %s): %v", s1, op, s2, err) if want != bad { t.Fail() @@ -740,7 +740,7 @@ func TestMessageHandler(t *testing.T) { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, -2); err == nil { + if err := state.PCall(ctx, 0, 0, -2); err == nil { t.Error("Running script did not return an error") } else if got, want := err.Error(), handledMessage; got != want { t.Errorf("state.Call(...).Error() = %q; want %q", got, want) @@ -831,7 +831,7 @@ func TestMessageHandler(t *testing.T) { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, -2); err == nil { + if err := state.PCall(ctx, 0, 0, -2); err == nil { t.Error("Running script did not return an error") } else if got, want := err.Error(), handledMessage; got != want { t.Errorf("state.Call(...).Error() = %q; want %q", got, want) @@ -930,7 +930,7 @@ func TestMessageHandler(t *testing.T) { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, -2); err == nil { + if err := state.PCall(ctx, 0, 0, -2); err == nil { t.Error("Running script did not return an error") } else if got, want := err.Error(), handledMessage; got != want { t.Errorf("state.Call(...).Error() = %q; want %q", got, want) @@ -955,6 +955,209 @@ func TestMessageHandler(t *testing.T) { } } }) + + t.Run("CrossesCall", func(t *testing.T) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + const originalMessage = "bork" + state.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + // Call a Go function using the Lua API that raises an error. + l.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + return 0, errors.New(originalMessage) + }) + if err := l.Call(ctx, 0, 0); err != nil { + return 0, err + } + t.Error("Calling error-raising function did not raise an error.") + return 0, nil + }) + if err := state.SetGlobal(ctx, "foo"); err != nil { + t.Fatal(err) + } + + // Message handler. + const handledMessage = "uwu" + messageHandlerCallCount := 0 + state.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + messageHandlerCallCount++ + + if got, want := l.Top(), 1; got != want { + t.Errorf("l.Top() = %d; want %d", got, want) + } + if got, want := l.Type(1), TypeString; got != want { + t.Errorf("l.Type(1) = %v; want %v", got, want) + } else if got, ok := l.ToString(1); got != originalMessage || !ok { + t.Errorf("l.ToString(1) = %q, %t; want %q, true", got, ok, originalMessage) + } + + errFuncDebug := l.Info(1) + if errFuncDebug == nil { + t.Error("l.Info(1) = nil") + } else { + if got, want := errFuncDebug.What, "Go"; got != want { + t.Errorf("l.Info(1).What = %q; want %q", got, want) + } + } + + goFuncDebug := l.Info(2) + if goFuncDebug == nil { + t.Error("l.Info(2) = nil") + } else { + if got, want := goFuncDebug.What, "Go"; got != want { + t.Errorf("l.Info(2).What = %q; want %q", got, want) + } + } + + scriptDebug := l.Info(3) + if scriptDebug == nil { + t.Error("l.Info(3) = nil") + } else { + if got, want := scriptDebug.What, "main"; got != want { + t.Errorf("l.Info(3).What = %q; want %q", got, want) + } + if got, want := scriptDebug.CurrentLine, 2; got != want { + t.Errorf("l.Info(3).CurrentLine = %d; want %d", got, want) + } + } + + l.PushString(handledMessage) + return 1, nil + }) + + const source = `-- Comment here to advance line number.` + "\n" + + `foo()` + "\n" + if err := state.Load(strings.NewReader(source), LiteralSource(source), "t"); err != nil { + t.Fatal(err) + } + + if err := state.PCall(ctx, 0, 0, -2); err == nil { + t.Error("Running script did not return an error") + } else if got, want := err.Error(), handledMessage; got != want { + t.Errorf("state.Call(...).Error() = %q; want %q", got, want) + } + if messageHandlerCallCount != 1 { + t.Errorf("message handler called %d times; want 1", messageHandlerCallCount) + } + + const errorObjectIndex = 2 + if got, want := state.Top(), errorObjectIndex; got != want { + t.Errorf("after state.Call(...), state.Top() = %d; want %d", got, want) + } + if state.Top() >= errorObjectIndex { + if got, want := state.Type(errorObjectIndex), TypeString; got != want { + t.Errorf("after state.Call(...), state.Type(%d) = %v; want %v", errorObjectIndex, got, want) + } else { + got, ok := state.ToString(errorObjectIndex) + want := handledMessage + if !ok || got != want { + t.Errorf("after state.Call(...), state.ToString(%d) = %q, %t; want %q, true", errorObjectIndex, got, ok, want) + } + } + } + }) + + t.Run("DoesNotCrossPCall", func(t *testing.T) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + const originalMessage = "bork" + state.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + // Call a Go function using the Lua API that raises an error. + l.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + return 0, errors.New(originalMessage) + }) + if err := l.PCall(ctx, 0, 0, 0); err != nil { + return 0, err + } + t.Error("Calling error-raising function did not raise an error.") + return 0, nil + }) + if err := state.SetGlobal(ctx, "foo"); err != nil { + t.Fatal(err) + } + + // Message handler. + const handledMessage = "uwu" + messageHandlerCallCount := 0 + state.PushClosure(0, func(ctx context.Context, l *State) (int, error) { + messageHandlerCallCount++ + + if got, want := l.Top(), 1; got != want { + t.Errorf("l.Top() = %d; want %d", got, want) + } + if got, want := l.Type(1), TypeString; got != want { + t.Errorf("l.Type(1) = %v; want %v", got, want) + } else if got, ok := l.ToString(1); got != originalMessage || !ok { + t.Errorf("l.ToString(1) = %q, %t; want %q, true", got, ok, originalMessage) + } + + errFuncDebug := l.Info(1) + if errFuncDebug == nil { + t.Error("l.Info(1) = nil") + } else { + if got, want := errFuncDebug.What, "Go"; got != want { + t.Errorf("l.Info(1).What = %q; want %q", got, want) + } + } + + scriptDebug := l.Info(2) + if scriptDebug == nil { + t.Error("l.Info(2) = nil") + } else { + if got, want := scriptDebug.What, "main"; got != want { + t.Errorf("l.Info(2).What = %q; want %q", got, want) + } + if got, want := scriptDebug.CurrentLine, 2; got != want { + t.Errorf("l.Info(2).CurrentLine = %d; want %d", got, want) + } + } + + l.PushString(handledMessage) + return 1, nil + }) + + const source = `-- Comment here to advance line number.` + "\n" + + `foo()` + "\n" + if err := state.Load(strings.NewReader(source), LiteralSource(source), "t"); err != nil { + t.Fatal(err) + } + + if err := state.PCall(ctx, 0, 0, -2); err == nil { + t.Error("Running script did not return an error") + } else if got, want := err.Error(), handledMessage; got != want { + t.Errorf("state.Call(...).Error() = %q; want %q", got, want) + } + if messageHandlerCallCount != 1 { + t.Errorf("message handler called %d times; want 1", messageHandlerCallCount) + } + + const errorObjectIndex = 2 + if got, want := state.Top(), errorObjectIndex; got != want { + t.Errorf("after state.Call(...), state.Top() = %d; want %d", got, want) + } + if state.Top() >= errorObjectIndex { + if got, want := state.Type(errorObjectIndex), TypeString; got != want { + t.Errorf("after state.Call(...), state.Type(%d) = %v; want %v", errorObjectIndex, got, want) + } else { + got, ok := state.ToString(errorObjectIndex) + want := handledMessage + if !ok || got != want { + t.Errorf("after state.Call(...), state.ToString(%d) = %q, %t; want %q, true", errorObjectIndex, got, ok, want) + } + } + } + }) } func TestRotate(t *testing.T) { @@ -1012,7 +1215,7 @@ func TestSuite(t *testing.T) { if err != nil { t.Fatal(err) } - if err := l.Call(ctx, 0, 0, 0); err != nil { + if err := l.Call(ctx, 0, 0); err != nil { t.Fatal(err) } }) @@ -1033,7 +1236,7 @@ func BenchmarkExec(b *testing.B) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { b.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { b.Fatal(err) } state.Pop(1) diff --git a/internal/lua/stringlib.go b/internal/lua/stringlib.go index 15f8f6f..49ba432 100644 --- a/internal/lua/stringlib.go +++ b/internal/lua/stringlib.go @@ -683,7 +683,7 @@ func gsubFunction(ctx context.Context, l *State, state *gsubState, match []int) if err != nil { return false, err } - if err := l.Call(ctx, n, 1, 0); err != nil { + if err := l.Call(ctx, n, 1); err != nil { return false, err } if !l.ToBoolean(-1) { @@ -1562,7 +1562,7 @@ func stringArithmetic(ctx context.Context, l *State, op luacode.ArithmeticOperat Where(l, 1), mtName[2:], l.Type(-2), l.Type(-1)) } l.Insert(-3) - if err := l.Call(ctx, 2, 1, 0); err != nil { + if err := l.Call(ctx, 2, 1); err != nil { return 0, err } return 1, nil diff --git a/internal/lua/stringlib_test.go b/internal/lua/stringlib_test.go index 3378bff..6beb31c 100644 --- a/internal/lua/stringlib_test.go +++ b/internal/lua/stringlib_test.go @@ -310,7 +310,7 @@ func TestStringGSub(t *testing.T) { funcIndex := state.Top() test.push(state) - if err := state.Call(ctx, state.Top()-funcIndex, 2, 0); err != nil { + if err := state.Call(ctx, state.Top()-funcIndex, 2); err != nil { t.Fatal("gsub:", err) } diff --git a/internal/lua/tablelib.go b/internal/lua/tablelib.go index d5d403d..6091642 100644 --- a/internal/lua/tablelib.go +++ b/internal/lua/tablelib.go @@ -385,7 +385,7 @@ func (ts *tableSorter) Less(i, j int) bool { return i < j } if hasCompareFunction { - ts.err = ts.l.Call(ts.ctx, 2, 1, 0) + ts.err = ts.l.Call(ts.ctx, 2, 1) if ts.err != nil { return i < j } diff --git a/internal/lua/tablelib_test.go b/internal/lua/tablelib_test.go index 1b62f0f..94f1b7d 100644 --- a/internal/lua/tablelib_test.go +++ b/internal/lua/tablelib_test.go @@ -93,7 +93,7 @@ func TestTableSort(t *testing.T) { state.PushClosure(0, test.compare) } - if err := state.Call(ctx, state.Top()-funcIndex, 0, 0); err != nil { + if err := state.Call(ctx, state.Top()-funcIndex, 0); err != nil { t.Error("table.sort:", err) } diff --git a/internal/lua/vm.go b/internal/lua/vm.go index 2657f46..4e4c734 100644 --- a/internal/lua/vm.go +++ b/internal/lua/vm.go @@ -990,7 +990,7 @@ func (l *State) exec(ctx context.Context) (err error) { if numArguments >= 0 { l.setTop(functionIndex + 1 + numArguments) } - isLua, err := l.prepareCall(ctx, functionIndex, numResults, false, nil) + isLua, err := l.prepareCall(ctx, functionIndex, callOptions{numResults: numResults}) if err != nil { return err } @@ -1019,7 +1019,11 @@ func (l *State) exec(ctx context.Context) (err error) { clear(l.stack[registerStart:functionIndex]) varargStart, varargEnd := frame.extraArgumentsRange() clear(l.stack[varargStart:varargEnd]) - if _, err := l.prepareCall(ctx, functionIndex, numResults, true, nil); err != nil { + _, err := l.prepareCall(ctx, functionIndex, callOptions{ + numResults: numResults, + isTailCall: true, + }) + if err != nil { return err } if len(l.callStack) <= callerDepth { @@ -1219,7 +1223,7 @@ func (l *State) exec(ctx context.Context) (err error) { } l.setTop(newTop) copy(l.stack[stateEnd:], l.stack[stateStart:]) - isLua, err := l.prepareCall(ctx, stateEnd, c, false, nil) + isLua, err := l.prepareCall(ctx, stateEnd, callOptions{numResults: c}) if err != nil { return err } diff --git a/internal/lua/vm_test.go b/internal/lua/vm_test.go index f813d11..7b06190 100644 --- a/internal/lua/vm_test.go +++ b/internal/lua/vm_test.go @@ -31,7 +31,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -60,7 +60,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -89,7 +89,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -114,7 +114,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), source, "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsTable(-1) { @@ -157,7 +157,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsTable(-1) { @@ -188,7 +188,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -213,7 +213,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -243,7 +243,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } const want = "Hello, World" @@ -270,7 +270,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } const want = "Hello, World!" @@ -311,7 +311,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -372,7 +372,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -416,7 +416,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -455,7 +455,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -494,7 +494,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -570,7 +570,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 0, 0); err != nil { + if err := state.Call(ctx, 0, 0); err != nil { t.Fatal(err) } @@ -598,7 +598,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if !state.IsNumber(-1) { @@ -632,7 +632,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } if got, want := state.Type(-1), TypeString; got != want { @@ -677,7 +677,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) } const wantResult = 85 @@ -704,7 +704,7 @@ func TestVM(t *testing.T) { if err := state.Load(strings.NewReader(source), Source(source), "t"); err != nil { t.Fatal(err) } - if err := state.Call(ctx, 0, 1, 0); err != nil { + if err := state.Call(ctx, 0, 1); err != nil { t.Fatal(err) }