diff --git a/internal/luacode/prototype.go b/internal/luacode/prototype.go index 78929ba..ab433bc 100644 --- a/internal/luacode/prototype.go +++ b/internal/luacode/prototype.go @@ -263,11 +263,18 @@ func (f *Prototype) UnmarshalBinary(data []byte) error { return nil } +// UpvalueDescriptor describes an upvalue in a [Prototype]. type UpvalueDescriptor struct { - Name string + Name string + // InStack is true if the upvalue refers to a local variable + // in the containing function. + // Otherwise, the upvalue refers to an upvalue in the containing function. InStack bool - Index uint8 - Kind VariableKind + // Index is the index of the local variable or upvalue + // to initialize the upvalue to. + // Its interpretation depends on the value of InStack. + Index uint8 + Kind VariableKind } type VariableKind uint8 diff --git a/internal/mylua/lua.go b/internal/mylua/lua.go index 977ed7d..406e3ff 100644 --- a/internal/mylua/lua.go +++ b/internal/mylua/lua.go @@ -499,7 +499,7 @@ func (l *State) ToString(idx int) (s string, ok bool) { if i > len(upvalues) { return "", false } - p = &upvalues[i-1] + p = l.resolveUpvalue(upvalues[i-1]) case isPseudo(idx): return "", false default: @@ -640,7 +640,10 @@ func (l *State) PushClosure(n int, f Function) { panic("too many upvalues") } upvalueStart := len(l.stack) - n - upvalues := slices.Clone(l.stack[upvalueStart:]) + upvalues := make([]upvalue, 0, n) + for _, v := range l.stack[upvalueStart:] { + upvalues = append(upvalues, standaloneUpvalue(v)) + } l.setTop(upvalueStart) l.push(goFunction{ id: nextID(), @@ -1114,6 +1117,11 @@ func (l *State) prepareCall(numArgs, numResults int) (isLua bool, err error) { for range maxMetaDepth { switch f := l.stack[functionIndex].(type) { case luaFunction: + if err := l.checkUpvalues(f.upvalues); err != nil { + l.popCallStack() + l.setTop(functionIndex - 1) + return true, err + } if !l.grow(len(l.stack) + int(f.proto.MaxStackSize) - numArgs) { l.popCallStack() l.setTop(functionIndex - 1) @@ -1121,6 +1129,11 @@ func (l *State) prepareCall(numArgs, numResults int) (isLua bool, err error) { } return true, nil case goFunction: + if err := l.checkUpvalues(f.upvalues); err != nil { + l.popCallStack() + l.setTop(functionIndex - 1) + return false, err + } if !l.grow(len(l.stack) + minStack) { l.popCallStack() l.setTop(functionIndex - 1) @@ -1222,9 +1235,11 @@ func (l *State) Load(r io.Reader, chunkName luacode.Source, mode string) (err er l.init() l.push(luaFunction{ - id: nextID(), - proto: p, - upvalues: []any{l.registry.get(RegistryIndexGlobals)}, + id: nextID(), + proto: p, + upvalues: []upvalue{ + standaloneUpvalue(l.registry.get(RegistryIndexGlobals)), + }, }) return nil } diff --git a/internal/mylua/value.go b/internal/mylua/value.go index ee971b4..a2a18ad 100644 --- a/internal/mylua/value.go +++ b/internal/mylua/value.go @@ -323,34 +323,34 @@ type stringValue struct { type goFunction struct { id uint64 cb Function - upvalues []any + upvalues []upvalue } func (f goFunction) functionID() uint64 { return f.id } -func (f goFunction) upvaluesSlice() []any { +func (f goFunction) upvaluesSlice() []upvalue { return f.upvalues } type luaFunction struct { id uint64 proto *luacode.Prototype - upvalues []any + upvalues []upvalue } func (f luaFunction) functionID() uint64 { return f.id } -func (f luaFunction) upvaluesSlice() []any { +func (f luaFunction) upvaluesSlice() []upvalue { return f.upvalues } type function interface { functionID() uint64 - upvaluesSlice() []any + upvaluesSlice() []upvalue } var ( @@ -358,6 +358,29 @@ var ( _ function = luaFunction{} ) +type upvalue struct { + p *any + stackIndex int +} + +func stackUpvalue(i int) upvalue { + return upvalue{stackIndex: i} +} + +func standaloneUpvalue(v any) upvalue { + return upvalue{ + p: &v, + stackIndex: -1, + } +} + +func (l *State) resolveUpvalue(uv upvalue) *any { + if uv.p == nil { + return &l.stack[uv.stackIndex] + } + return uv.p +} + var globalIDs struct { mu sync.Mutex n uint64 diff --git a/internal/mylua/vm.go b/internal/mylua/vm.go index 996b4e1..9877079 100644 --- a/internal/mylua/vm.go +++ b/internal/mylua/vm.go @@ -57,6 +57,9 @@ func (l *State) loadLuaFrame() (frame *callFrame, f luaFunction, registers []any if !ok { return frame, luaFunction{}, nil, fmt.Errorf("internal error: call frame function is a %T", v) } + if err := l.checkUpvalues(f.upvalues); err != nil { + return frame, f, nil, err + } registerStart := frame.registerStart() registerEnd := registerStart + int(f.proto.MaxStackSize) if !l.grow(registerEnd) { @@ -66,6 +69,16 @@ func (l *State) loadLuaFrame() (frame *callFrame, f luaFunction, registers []any return frame, f, registers, nil } +func (l *State) checkUpvalues(upvalues []upvalue) error { + frame := l.frame() + for i, uv := range upvalues { + if uv.stackIndex >= frame.framePointer() { + return fmt.Errorf("internal error: function upvalue [%d] inside current frame", i) + } + } + return nil +} + func (l *State) exec() error { if len(l.callStack) == 0 { panic("exec called on empty call stack") @@ -124,7 +137,7 @@ func (l *State) exec() error { registers[i.ArgA()] = false case luacode.OpLFalseSkip: registers[i.ArgA()] = false - nextPC = frame.pc + 2 + nextPC++ case luacode.OpLoadTrue: registers[i.ArgA()] = true case luacode.OpLoadNil: @@ -132,7 +145,8 @@ func (l *State) exec() error { case luacode.OpGetUpval: registers[i.ArgA()] = f.upvalues[i.ArgB()] case luacode.OpSetUpval: - f.upvalues[i.ArgB()] = registers[i.ArgA()] + p := l.resolveUpvalue(f.upvalues[i.ArgB()]) + *p = registers[i.ArgA()] case luacode.OpGetTabUp: var err error registers[i.ArgA()], err = l.index(f.upvalues[i.ArgB()], importConstant(f.proto.Constants[i.ArgC()])) @@ -226,6 +240,37 @@ func (l *State) exec() error { l.setTop(callerValueTop) return err } + case luacode.OpJmp: + nextPC += int(i.J()) + case luacode.OpTest: + cond := toBoolean(registers[i.ArgA()]) + if cond != i.K() { + nextPC++ + } + case luacode.OpTestSet: + rb := registers[i.ArgB()] + cond := toBoolean(rb) + if cond != i.K() { + nextPC++ + } else { + registers[i.ArgA()] = rb + } + case luacode.OpCall: + numArguments := int(i.ArgB()) + numResults := int(i.ArgC()) - 1 + l.setTop(frame.registerStart() + int(i.ArgA()) + 1 + numArguments) + isLua, err := l.prepareCall(numArguments, numResults) + if err != nil { + l.setTop(callerValueTop) + return err + } + if isLua { + frame, f, registers, err = l.loadLuaFrame() + if err != nil { + l.setTop(callerValueTop) + return err + } + } case luacode.OpReturn: resultStackStart := frame.registerStart() + int(i.ArgA()) numResults := int(i.ArgB()) - 1 @@ -262,6 +307,44 @@ func (l *State) exec() error { l.setTop(callerValueTop) return err } + case luacode.OpSetList: + t := registers[i.ArgA()] + for idx := range i.ArgB() { + err := l.setIndex(t, int64(idx)+1, registers[i.ArgC()+idx+1]) + if err != nil { + l.setTop(callerValueTop) + return err + } + } + case luacode.OpClosure: + p := f.proto.Functions[i.ArgBx()] + upvalues := make([]upvalue, len(p.Upvalues)) + for i, uv := range p.Upvalues { + if uv.InStack { + upvalues[i] = stackUpvalue(frame.registerStart() + int(uv.Index)) + } else { + upvalues[i] = f.upvalues[uv.Index] + } + } + registers[i.ArgA()] = luaFunction{ + id: nextID(), + proto: p, + upvalues: upvalues, + } + case luacode.OpVararg: + numWanted := int(i.ArgC()) - 1 + if numWanted == MultipleReturns { + numWanted = frame.numExtraArguments + } + a := frame.registerStart() + int(i.ArgA()) + if !l.grow(a + numWanted) { + l.setTop(callerValueTop) + return errStackOverflow + } + l.setTop(a + numWanted) + varargStart, varargEnd := frame.extraArgumentsRange() + n := copy(l.stack[a:], l.stack[varargStart:varargEnd]) + clear(l.stack[a+n:]) case luacode.OpVarargPrep: if frame.pc != 0 { l.setTop(callerValueTop)