Skip to content

Commit

Permalink
Match C Lua's pcall semantics
Browse files Browse the repository at this point in the history
Stop propagating message handler through a protected call.
Split off `*State.PCall` from `*State.Call`.
  • Loading branch information
zombiezen committed Jan 29, 2025
1 parent 8708a8d commit eff4bd7
Show file tree
Hide file tree
Showing 14 changed files with 355 additions and 74 deletions.
14 changes: 7 additions & 7 deletions internal/frontend/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func NewEval(storeDir zbstore.Directory, store *jsonrpc.Client, cacheDB string)
if err := eval.l.Load(strings.NewReader(preludeSource), lua.AbstractSource("(prelude)"), "t"); err != nil {
return nil, err
}
if err := eval.l.Call(ctx, 0, 0, 0); err != nil {
if err := eval.l.Call(ctx, 0, 0); err != nil {
return nil, err
}

Expand Down Expand Up @@ -206,7 +206,7 @@ func (eval *Eval) File(ctx context.Context, exprFile string, attrPaths []string)
if err := loadFile(&eval.l, exprFile); err != nil {
return nil, err
}
if err := eval.l.Call(ctx, 0, 1, -2); err != nil {
if err := eval.l.PCall(ctx, 0, 1, -2); err != nil {
return nil, err
}
return eval.attrPaths(ctx, attrPaths)
Expand All @@ -232,7 +232,7 @@ func (eval *Eval) Expression(ctx context.Context, expr string, attrPaths []strin
if err := loadExpression(&eval.l, expr); err != nil {
return nil, err
}
if err := eval.l.Call(ctx, 0, 1, -2); err != nil {
if err := eval.l.PCall(ctx, 0, 1, -2); err != nil {
return nil, err
}
return eval.attrPaths(ctx, attrPaths)
Expand Down Expand Up @@ -261,7 +261,7 @@ func (eval *Eval) attrPaths(ctx context.Context, paths []string) ([]any, error)
return result, fmt.Errorf("%s: %v", p, err)
}
eval.l.PushValue(-2)
if err := eval.l.Call(ctx, 1, 1, 0); err != nil {
if err := eval.l.Call(ctx, 1, 1); err != nil {
return result, fmt.Errorf("%s: %v", p, err)
}
x, err := luaToGo(ctx, &eval.l)
Expand Down Expand Up @@ -411,7 +411,7 @@ func loadFunction(ctx context.Context, l *lua.State) (int, error) {
}
l.PushValue(lua.UpvalueIndex(1))
l.Insert(1)
if err := l.Call(ctx, maxLoadArgs, lua.MultipleReturns, 0); err != nil {
if err := l.Call(ctx, maxLoadArgs, lua.MultipleReturns); err != nil {
return 0, err
}
return l.Top(), nil
Expand Down Expand Up @@ -532,7 +532,7 @@ func dofileFunction(ctx context.Context, l *lua.State) (int, error) {
// loadfile(filename)
l.PushValue(lua.UpvalueIndex(1))
l.Insert(1)
if err := l.Call(ctx, 1, 2, 0); err != nil {
if err := l.Call(ctx, 1, 2); err != nil {
return 0, fmt.Errorf("dofile: %v", err)
}
if l.IsNil(-2) {
Expand All @@ -542,7 +542,7 @@ func dofileFunction(ctx context.Context, l *lua.State) (int, error) {
l.Pop(1)

// Call the loaded function.
if err := l.Call(ctx, 0, lua.MultipleReturns, 0); err != nil {
if err := l.Call(ctx, 0, lua.MultipleReturns); err != nil {
return 0, fmt.Errorf("dofile %s: %v", resolved, err)
}
return l.Top(), nil
Expand Down
4 changes: 2 additions & 2 deletions internal/lua/auxlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func CallMeta(ctx context.Context, l *State, obj int, event string) (bool, error
return false, nil
}
l.PushValue(obj)
if err := l.Call(ctx, 1, 1, 0); err != nil {
if err := l.Call(ctx, 1, 1); err != nil {
return true, fmt.Errorf("lua: call metafield %q: %w", event, err)
}
return true, nil
Expand Down Expand Up @@ -371,7 +371,7 @@ func Require(ctx context.Context, l *State, modName string, global bool, openf F
l.Pop(1) // remove field
l.PushClosure(0, openf)
l.PushString(modName)
if err := l.Call(ctx, 1, 1, 0); err != nil {
if err := l.Call(ctx, 1, 1); err != nil {
return fmt.Errorf("lua: require %q: %w", modName, err)
}
l.PushValue(-1)
Expand Down
2 changes: 1 addition & 1 deletion internal/lua/auxlib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestWhere(t *testing.T) {
if err := state.Load(strings.NewReader(test.luaCode), chunkName, "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(ctx, 0, 1, 0); err != nil {
if err := state.Call(ctx, 0, 1); err != nil {
t.Fatal(err)
}

Expand Down
12 changes: 6 additions & 6 deletions internal/lua/baselib.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func newBaseDofile(loadfile Function) Function {
// loadfile(filename)
l.PushClosure(0, loadfile)
l.Insert(1)
if err := l.Call(ctx, 1, 2, 0); err != nil {
if err := l.Call(ctx, 1, 2); err != nil {
return 0, err
}
if l.IsNil(-2) {
Expand All @@ -128,7 +128,7 @@ func newBaseDofile(loadfile Function) Function {
l.Pop(1)

// Call the loaded function.
if err := l.Call(ctx, 0, MultipleReturns, 0); err != nil {
if err := l.Call(ctx, 0, MultipleReturns); err != nil {
return 0, err
}
return l.Top(), nil
Expand Down Expand Up @@ -358,7 +358,7 @@ func basePairs(ctx context.Context, l *State) (int, error) {
}
if Metafield(l, 1, "__pairs") != TypeNil {
l.PushValue(1) // self for metamethod
if err := l.Call(ctx, 1, 3, 0); err != nil {
if err := l.Call(ctx, 1, 3); err != nil {
return 0, err
}
return 3, nil
Expand All @@ -378,7 +378,7 @@ func basePCall(ctx context.Context, l *State) (int, error) {
l.PushBoolean(true)
l.Insert(1)

if err := l.Call(ctx, l.Top()-2, MultipleReturns, 0); err != nil {
if err := l.PCall(ctx, l.Top()-2, MultipleReturns, 0); err != nil {
l.PushBoolean(false)
// TODO(someday): Push error object from err.
l.PushString(err.Error())
Expand Down Expand Up @@ -408,7 +408,7 @@ func baseXPCall(ctx context.Context, l *State) (int, error) {
l.PushBoolean(true)
l.Rotate(3, 1)

if err := l.Call(ctx, numArgs, MultipleReturns, 1); err != nil {
if err := l.PCall(ctx, numArgs, MultipleReturns, 1); err != nil {
l.PushBoolean(false)
// TODO(someday): Push error object from err.
l.PushString(err.Error())
Expand Down Expand Up @@ -647,7 +647,7 @@ func (r *luaFunctionReader) ReadByte() (byte, error) {
}
r.s, r.i = "", 0 // Prevent unreading.
r.l.PushValue(1)
r.err = r.l.Call(r.ctx, 0, 1, 0)
r.err = r.l.Call(r.ctx, 0, 1)
if r.err != nil {
return 0, r.err
}
Expand Down
8 changes: 4 additions & 4 deletions internal/lua/baselib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestAssert(t *testing.T) {
}

state.PushBoolean(true)
if err := state.Call(ctx, 1, MultipleReturns, 0); err != nil {
if err := state.Call(ctx, 1, MultipleReturns); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -57,7 +57,7 @@ func TestAssert(t *testing.T) {
}

state.PushBoolean(false)
if err := state.Call(ctx, 1, MultipleReturns, 0); err == nil {
if err := state.Call(ctx, 1, MultipleReturns); err == nil {
t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error")
} else if got, want := err.Error(), "assertion failed!"; got != want {
t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want)
Expand All @@ -82,7 +82,7 @@ func TestAssert(t *testing.T) {
}

state.PushNil()
if err := state.Call(ctx, 1, MultipleReturns, 0); err == nil {
if err := state.Call(ctx, 1, MultipleReturns); err == nil {
t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error")
} else if got, want := err.Error(), "assertion failed!"; got != want {
t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want)
Expand All @@ -109,7 +109,7 @@ func TestAssert(t *testing.T) {
state.PushBoolean(false)
const msg = "bork bork bork"
state.PushString(msg)
if err := state.Call(ctx, 2, MultipleReturns, 0); err == nil {
if err := state.Call(ctx, 2, MultipleReturns); err == nil {
t.Error("state.Call(ctx, 1, MultipleReturns, 0) did not return an error")
} else if got, want := err.Error(), msg; got != want {
t.Errorf("state.Call(ctx, 1, MultipleReturns, 0) error = %q; want %q", got, want)
Expand Down
2 changes: 1 addition & 1 deletion internal/lua/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Example() {
if err := state.Load(strings.NewReader(luaSource), luaSource, "t"); err != nil {
log.Fatal(err)
}
if err := state.Call(ctx, 0, 0, 0); err != nil {
if err := state.Call(ctx, 0, 0); err != nil {
log.Fatal(err)
}
// Output:
Expand Down
108 changes: 91 additions & 17 deletions internal/lua/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,7 @@ func (l *State) SetUserValue(idx int, n int) bool {
return true
}

// Call calls a function (or callable object) in protected mode.
// Call calls a function (or callable object).
//
// To do a call you must use the following protocol:
// first, the function to be called is pushed onto the stack;
Expand All @@ -1454,20 +1454,79 @@ func (l *State) SetUserValue(idx int, n int) bool {
//
// # Error Handling
//
// If an error occurs during the function call,
// it is returned as a Go error value.
// (This is in contrast to the C Lua API, which longjmps to the last protected call.)
// If a caller used [*State.PCall] to set a message handler,
// then the message handler is called before unwinding the stack
// and before Call returns.
func (l *State) Call(ctx context.Context, nArgs, nResults int) error {
if nArgs < 0 {
return errors.New("negative argument count")
}
if nResults < 0 && nResults != MultipleReturns {
return errors.New("negative result count")
}
if l.Top() < nArgs+1 {
return errMissingArguments
}
l.init()
if nResults != MultipleReturns && cap(l.stack)-len(l.stack) < nResults-nArgs {
l.Pop(nArgs + 1)
return fmt.Errorf("results from function overflow current stack size")
}

isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, callOptions{
numResults: nResults,
})
if err != nil {
return err
}
if isLua {
if err := l.exec(ctx); err != nil {
return err
}
}
return nil
}

// PCall calls a function (or callable object) in protected mode.
//
// To do a call you must use the following protocol:
// first, the function to be called is pushed onto the stack;
// then, the arguments to the call are pushed in direct order;
// that is, the first argument is pushed first.
// Finally you call PCall;
// nArgs is the number of arguments that you pushed onto the stack.
// When the function returns,
// all arguments and the function value are popped
// and the call results are pushed onto the stack.
// The number of results is adjusted to nResults,
// unless nResults is [MultipleReturns].
// In this case, all results from the function are pushed;
// Lua takes care that the returned values fit into the stack space,
// but it does not ensure any extra space in the stack.
// The function results are pushed onto the stack in direct order
// (the first result is pushed first),
// so that after the call the last result is on the top of the stack.
// PCall always removes the function and its arguments from the stack.
//
// # Error Handling
//
// If the msgHandler argument is 0,
// then if an error occurs during the function call,
// it is returned as a Go error value.
// (This is in contrast to the C Lua API which pushes an error object onto the stack.)
// Otherwise, msgHandler is the stack index of a message handler.
// In case of runtime errors, this handler will be called with the error object
// and Call will push its return value onto the stack.
// The return value's string value will be used as a Go error returned by Call.
// and PCall will push its return value onto the stack.
// The return value's string value will be used as a Go error returned by PCall.
//
// Typically, the message handler is used to add more debug information to the error object,
// such as a stack traceback.
// Such information cannot be gathered after the return of a [State] method,
// since by then the stack will have been unwound.
func (l *State) Call(ctx context.Context, nArgs, nResults, msgHandler int) error {
func (l *State) PCall(ctx context.Context, nArgs, nResults, msgHandler int) error {
if nArgs < 0 {
return errors.New("negative argument count")
}
Expand Down Expand Up @@ -1497,7 +1556,11 @@ func (l *State) Call(ctx context.Context, nArgs, nResults, msgHandler int) error
}
}

isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, nResults, false, msgHandlerFunction)
isLua, err := l.prepareCall(ctx, len(l.stack)-nArgs-1, callOptions{
numResults: nResults,
protected: true,
messageHandler: msgHandlerFunction,
})
if err != nil {
if msgHandler != 0 {
l.push(l.errorToValue(err))
Expand Down Expand Up @@ -1526,7 +1589,7 @@ func (l *State) call(ctx context.Context, numResults int, f value, args ...value
functionIndex := len(l.stack)
l.stack = append(l.stack, f)
l.stack = append(l.stack, args...)
isLua, err := l.prepareCall(ctx, functionIndex, numResults, false, nil)
isLua, err := l.prepareCall(ctx, functionIndex, callOptions{numResults: numResults})
if err != nil {
return err
}
Expand All @@ -1550,6 +1613,14 @@ func (l *State) call1(ctx context.Context, f value, args ...value) (value, error
return v, nil
}

// callOptions holds optional arguments to [*State.prepareCall].
type callOptions struct {
numResults int
isTailCall bool
protected bool
messageHandler function
}

// prepareCall pushes a new [callFrame] onto l.callStack
// to start executing a new function.
// The caller must have pushed the function to call
Expand Down Expand Up @@ -1579,11 +1650,14 @@ func (l *State) call1(ctx context.Context, f value, args ...value) (value, error
// If prepareCall calls a Go function and it returns an error,
// it will call the message handler (if any)
// before popping the Go function's frame off the call stack.
func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int, isTailCall bool, messageHandler function) (isLua bool, err error) {
func (l *State) prepareCall(ctx context.Context, functionIndex int, opts callOptions) (isLua bool, err error) {
var nextMessageHandler *messageHandlerState
if messageHandler != nil {
nextMessageHandler = &messageHandlerState{function: messageHandler}
} else {
switch {
case opts.messageHandler != nil:
nextMessageHandler = &messageHandlerState{function: opts.messageHandler}
case opts.protected:
nextMessageHandler = nil
default:
nextMessageHandler = l.frame().messageHandler
}

Expand All @@ -1596,8 +1670,8 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int,
}
newFrame := callFrame{
functionIndex: functionIndex,
numResults: numResults,
isTailCall: isTailCall,
numResults: opts.numResults,
isTailCall: opts.isTailCall,
messageHandler: nextMessageHandler,
}
if !l.grow(newFrame.registerStart() + int(f.proto.MaxStackSize)) {
Expand All @@ -1613,7 +1687,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int,
newFrame.numExtraArguments = numExtraArguments
}
}
if isTailCall {
if opts.isTailCall {
// Move function and arguments up to the frame pointer.
frame := l.frame()
fp := frame.framePointer()
Expand All @@ -1639,7 +1713,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int,

l.callStack = append(l.callStack, callFrame{
functionIndex: functionIndex,
numResults: numResults,
numResults: opts.numResults,
messageHandler: nextMessageHandler,
})
n, err := f.cb(ctx, l)
Expand All @@ -1659,7 +1733,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int,
// and pop call frames.
newStackTop := functionIndex
newCallStackTop := len(l.callStack) - 1
if isTailCall {
if opts.isTailCall {
newCallStackTop--
newStackTop = l.callStack[newStackTop].framePointer()
}
Expand All @@ -1668,7 +1742,7 @@ func (l *State) prepareCall(ctx context.Context, functionIndex, numResults int,
return false, err
}

if isTailCall {
if opts.isTailCall {
// Pop the Go function stack frame
// so that finishCall will move the results to the caller's caller's frame.
l.popCallStack()
Expand Down Expand Up @@ -1927,7 +2001,7 @@ func (l *State) concatMetamethod(ctx context.Context) error {
rotate(l.stack[len(l.stack)-3:], 1)

// Call metamethod.
isLua, err := l.prepareCall(ctx, len(l.stack)-3, 1, false, nil)
isLua, err := l.prepareCall(ctx, len(l.stack)-3, callOptions{numResults: 1})
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit eff4bd7

Please sign in to comment.