From 8a08feab36ce60ee3bbdcc6a12b7a2ca1dc37a1f Mon Sep 17 00:00:00 2001 From: Roxy Light Date: Tue, 10 Dec 2024 23:35:46 -0800 Subject: [PATCH] Lazily evaluate registers Fixes aliasing issues when growing the stack as part of a metamethod. As a side-effect, add bounds checks on registers. --- internal/luacode/instruction.go | 9 +- internal/mylua/vm.go | 713 ++++++++++++++++++++++++-------- internal/mylua/vm_test.go | 74 ++++ 3 files changed, 622 insertions(+), 174 deletions(-) diff --git a/internal/luacode/instruction.go b/internal/luacode/instruction.go index 8b93052..af630cd 100644 --- a/internal/luacode/instruction.go +++ b/internal/luacode/instruction.go @@ -450,7 +450,14 @@ const ( // A B C R[A][K[B]:string] := RK(C) OpSetField OpCode = 18 // SETFIELD - // A B C k R[A] := {} + // OpNewTable creates a new table and stores it in R[A]. + // C is a hint that integer ("array") keys [1,C] are expected; + // 2^B is a hint as to how many non-"array" keys are expected. + // OpNewTable is always followed by an [OpExtraArg]. + // If k is set, then the array size is + // the extra argument multiplied by 256 plus C. + // + // A B C k R[A] := {} OpNewTable OpCode = 19 // NEWTABLE // A B C R[A+1] := R[B]; R[A] := R[B][RK(C):string] diff --git a/internal/mylua/vm.go b/internal/mylua/vm.go index 24b170b..15b2050 100644 --- a/internal/mylua/vm.go +++ b/internal/mylua/vm.go @@ -50,23 +50,21 @@ func (frame callFrame) extraArgumentsRange() (start, end int) { // and returns a slice of the value stack to be used as the function's registers. // loadLuaFrame returns an error if the value stack is not large enough // for the function's local registers. -func (l *State) loadLuaFrame() (frame *callFrame, f luaFunction, registers []value, err error) { +func (l *State) loadLuaFrame() (frame *callFrame, f luaFunction, err error) { frame = l.frame() v := l.stack[frame.functionIndex] f, ok := v.(luaFunction) if !ok { - return frame, luaFunction{}, nil, fmt.Errorf("internal error: call frame function is a %T", v) + return frame, luaFunction{}, fmt.Errorf("internal error: call frame function is a %T", v) } if err := l.checkUpvalues(f.upvalues); err != nil { - return frame, f, nil, err + return frame, f, err } - registerStart := frame.registerStart() - registerEnd := registerStart + int(f.proto.MaxStackSize) + registerEnd := frame.registerStart() + int(f.proto.MaxStackSize) if !l.grow(registerEnd) { - return frame, f, nil, errStackOverflow + return frame, f, errStackOverflow } - registers = l.stack[registerStart:registerEnd] - return frame, f, registers, nil + return frame, f, nil } func (l *State) checkUpvalues(upvalues []upvalue) error { @@ -79,34 +77,75 @@ func (l *State) checkUpvalues(upvalues []upvalue) error { return nil } -func (l *State) exec() error { +func (l *State) exec() (err error) { if len(l.callStack) == 0 { panic("exec called on empty call stack") } callerDepth := len(l.callStack) - 1 - defer func() { + frame, f, firstLoadError := l.loadLuaFrame() + defer func(callerValueTop int) { + if err != nil { + // TODO(someday): Message handler. + l.setTop(callerValueTop) + } + clear(l.callStack[callerDepth:]) l.callStack = l.callStack[:callerDepth] - }() + }(frame.framePointer()) + if firstLoadError != nil { + return firstLoadError + } + + // registers returns the slice of l.stack + // that represents the register file for the function at the top of the call stack. + registers := func() []value { + start := frame.registerStart() + return l.stack[start : start+int(f.proto.MaxStackSize)] + } + + // register returns a pointer to the element of l.stack + // for the i'th register of the function at the top of the call stack. + register := func(r []value, i uint8) (*value, error) { + if int(i) >= len(r) { + return nil, fmt.Errorf("decode instruction (pc=%d): register %d out-of-bounds (stack is %d slots)", + frame.pc, i, len(r)) + } + return &r[i], nil + } - frame, f, registers, err := l.loadLuaFrame() - callerValueTop := frame.framePointer() - if err != nil { - l.setTop(callerValueTop) - return err + constant := func(i uint32) (luacode.Value, error) { + if int64(i) >= int64(len(f.proto.Constants)) { + return luacode.Value{}, fmt.Errorf("decode instruction (pc=%d): constant %d out-of-bounds (table has %d entries)", frame.pc, i, len(f.proto.Constants)) + } + return f.proto.Constants[i], nil } - rkC := func(i luacode.Instruction) value { + fUpvalue := func(i uint8) (*value, error) { + if int(i) >= len(f.upvalues) { + return nil, fmt.Errorf("decode instruction (pc=%d): upvalue %d out-of-bounds (function has %d upvalues)", frame.pc, i, len(f.upvalues)) + } + return l.resolveUpvalue(f.upvalues[i]), nil + } + + rkC := func(r []value, i luacode.Instruction) (value, error) { + c := i.ArgC() if i.K() { - return importConstant(f.proto.Constants[i.ArgC()]) + kc, err := constant(uint32(c)) + if err != nil { + return nil, err + } + return importConstant(kc), nil } else { - return registers[i.ArgC()] + rc, err := register(r, c) + if err != nil { + return nil, err + } + return *rc, nil } } for len(l.callStack) > callerDepth { if frame.pc >= len(f.proto.Code) { - l.setTop(callerValueTop) return fmt.Errorf("jumped out of bounds") } i := f.proto.Code[frame.pc] @@ -120,103 +159,254 @@ func (l *State) exec() error { switch i.OpCode() { case luacode.OpMove: - registers[i.ArgA()] = registers[i.ArgB()] + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + *ra = *rb case luacode.OpLoadI: - registers[i.ArgA()] = integerValue(i.ArgBx()) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = integerValue(i.ArgBx()) case luacode.OpLoadF: - registers[i.ArgA()] = floatValue(i.ArgBx()) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = floatValue(i.ArgBx()) case luacode.OpLoadK: - registers[i.ArgA()] = importConstant(f.proto.Constants[i.ArgBx()]) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + kb, err := constant(uint32(i.ArgBx())) + if err != nil { + return err + } + *ra = importConstant(kb) case luacode.OpLoadKX: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } arg, err := decodeExtraArg(frame, f.proto) if err != nil { - l.setTop(callerValueTop) return err } - registers[i.ArgA()] = importConstant(f.proto.Constants[arg]) + nextPC++ // Skip extra arg. + karg, err := constant(arg) + if err != nil { + return err + } + *ra = importConstant(karg) case luacode.OpLoadFalse: - registers[i.ArgA()] = booleanValue(false) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = booleanValue(false) case luacode.OpLFalseSkip: - registers[i.ArgA()] = booleanValue(false) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = booleanValue(false) nextPC++ case luacode.OpLoadTrue: - registers[i.ArgA()] = booleanValue(true) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = booleanValue(true) case luacode.OpLoadNil: - clear(registers[i.ArgA() : i.ArgA()+i.ArgB()]) + start := i.ArgA() + end := start + i.ArgB() + if end > start { + r := registers() + if _, err := register(r, end-1); err != nil { + return err + } + clear(r[start:end]) + } case luacode.OpGetUpval: - p := l.resolveUpvalue(f.upvalues[i.ArgB()]) - registers[i.ArgA()] = *p + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + ub, err := fUpvalue(i.ArgB()) + if err != nil { + return err + } + *ra = *ub case luacode.OpSetUpval: - p := l.resolveUpvalue(f.upvalues[i.ArgB()]) - *p = registers[i.ArgA()] + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + ub, err := fUpvalue(i.ArgB()) + if err != nil { + return err + } + *ub = *ra case luacode.OpGetTabUp: - u := l.resolveUpvalue(f.upvalues[i.ArgB()]) - var err error - registers[i.ArgA()], err = l.index(*u, importConstant(f.proto.Constants[i.ArgC()])) + if _, err := register(registers(), i.ArgA()); err != nil { + return err + } + ub, err := fUpvalue(i.ArgB()) if err != nil { - l.setTop(callerValueTop) return err } + kc, err := constant(uint32(i.ArgC())) + if err != nil { + return err + } + result, err := l.index(*ub, importConstant(kc)) + if err != nil { + return err + } + // index may call a metamethod and grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = result case luacode.OpGetTable: - var err error - registers[i.ArgA()], err = l.index(registers[i.ArgB()], registers[i.ArgC()]) + r := registers() + if _, err := register(r, i.ArgA()); err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + rc, err := register(r, i.ArgC()) + if err != nil { + return err + } + result, err := l.index(*rb, *rc) + if err != nil { + return err + } + // index may call a metamethod and grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err := register(registers(), i.ArgA()) if err != nil { - l.setTop(callerValueTop) return err } + *ra = result case luacode.OpGetI: - var err error - registers[i.ArgA()], err = l.index(registers[i.ArgB()], integerValue(i.ArgC())) + r := registers() + if _, err := register(r, i.ArgA()); err != nil { + return err + } + rb, err := register(r, i.ArgB()) if err != nil { - l.setTop(callerValueTop) return err } + result, err := l.index(*rb, integerValue(i.ArgC())) + if err != nil { + return err + } + // index may call a metamethod and grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = result case luacode.OpGetField: - var err error - registers[i.ArgA()], err = l.index(registers[i.ArgB()], importConstant(f.proto.Constants[i.ArgC()])) + r := registers() + if _, err := register(r, i.ArgA()); err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + kc, err := constant(uint32(i.ArgC())) + if err != nil { + return err + } + result, err := l.index(*rb, importConstant(kc)) if err != nil { - l.setTop(callerValueTop) return err } + // index may call a metamethod and grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = result case luacode.OpSetTabUp: - u := l.resolveUpvalue(f.upvalues[i.ArgA()]) - err := l.setIndex( - *u, - importConstant(f.proto.Constants[i.ArgB()]), - rkC(i), - ) + ua, err := fUpvalue(i.ArgA()) if err != nil { - l.setTop(callerValueTop) + return err + } + kb, err := constant(uint32(i.ArgB())) + if err != nil { + return err + } + c, err := rkC(registers(), i) + if err != nil { + return err + } + if err := l.setIndex(*ua, importConstant(kb), c); err != nil { return err } case luacode.OpSetTable: - err := l.setIndex( - registers[i.ArgA()], - registers[i.ArgB()], - rkC(i), - ) + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgA()) if err != nil { - l.setTop(callerValueTop) + return err + } + c, err := rkC(registers(), i) + if err != nil { + return err + } + if err := l.setIndex(*ra, *rb, c); err != nil { return err } case luacode.OpSetI: - err := l.setIndex( - registers[i.ArgA()], - integerValue(i.ArgB()), - rkC(i), - ) + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + c, err := rkC(registers(), i) if err != nil { - l.setTop(callerValueTop) + return err + } + if err := l.setIndex(*ra, integerValue(i.ArgB()), c); err != nil { return err } case luacode.OpSetField: - err := l.setIndex( - registers[i.ArgA()], - importConstant(f.proto.Constants[i.ArgB()]), - rkC(i), - ) + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + kb, err := constant(uint32(i.ArgB())) if err != nil { - l.setTop(callerValueTop) + return err + } + c, err := rkC(registers(), i) + if err != nil { + return err + } + if err := l.setIndex(*ra, importConstant(kb), c); err != nil { return err } case luacode.OpNewTable: @@ -229,47 +419,87 @@ func (l *State) exec() error { if i.K() { arg, err := decodeExtraArg(frame, f.proto) if err != nil { - l.setTop(callerValueTop) return err } arraySize += int(arg) * (1 << 8) } - registers[i.ArgA()] = newTable(hashSize + arraySize) + nextPC++ // Extra arg is always present even if unused. + + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + *ra = newTable(hashSize + arraySize) case luacode.OpSelf: - rb := registers[i.ArgB()] - registers[int(i.ArgA())+1] = rb - var err error - registers[i.ArgA()], err = l.index(rb, rkC(i)) + r := registers() + a := i.ArgA() + ra1, err := register(r, a+1) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + c, err := rkC(r, i) + if err != nil { + return err + } + + *ra1 = *rb + result, err := l.index(*rb, c) if err != nil { - l.setTop(callerValueTop) return err } + // index may call a metamethod and grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err := register(registers(), a) + if err != nil { + return err + } + *ra = result case luacode.OpAddI, luacode.OpSHRI: + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } c := luacode.IntegerValue(int64(luacode.SignedArg(i.ArgC()))) - if rb, isNumber := exportNumericConstant(registers[i.ArgB()]); isNumber { + if kb, isNumber := exportNumericConstant(*rb); isNumber { op, ok := i.OpCode().ArithmeticOperator() if !ok { panic("operator should always be defined") } - result, err := luacode.Arithmetic(op, rb, c) + result, err := luacode.Arithmetic(op, kb, c) if err != nil { - l.setTop(callerValueTop) return err } - registers[i.ArgA()] = importConstant(result) + *ra = importConstant(result) // The next instruction is a fallback metamethod invocation. nextPC++ } case luacode.OpSHLI: // Separate case because SHLI's arguments are in the opposite order. + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } c := luacode.IntegerValue(int64(luacode.SignedArg(i.ArgC()))) - if rb, isNumber := exportNumericConstant(registers[i.ArgB()]); isNumber { - result, err := luacode.Arithmetic(luacode.ShiftLeft, c, rb) + if kb, isNumber := exportNumericConstant(*rb); isNumber { + result, err := luacode.Arithmetic(luacode.ShiftLeft, c, kb) if err != nil { - l.setTop(callerValueTop) return err } - registers[i.ArgA()] = importConstant(result) + *ra = importConstant(result) // The next instruction is a fallback metamethod invocation. nextPC++ } @@ -283,22 +513,32 @@ func (l *State) exec() error { luacode.OpBAndK, luacode.OpBOrK, luacode.OpBXORK: - kc := f.proto.Constants[i.ArgC()] + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + kc, err := constant(uint32(i.ArgC())) + if err != nil { + return err + } if !kc.IsNumber() { - l.setTop(callerValueTop) - return fmt.Errorf("%v on non-numeric constant %v", i.OpCode(), kc) + return fmt.Errorf("decode instruction (pc=%d): %v on non-numeric constant %v", frame.pc, i.OpCode(), kc) } - if rb, isNumber := exportNumericConstant(registers[i.ArgB()]); isNumber { + if rb, isNumber := exportNumericConstant(*rb); isNumber { op, ok := i.OpCode().ArithmeticOperator() if !ok { panic("operator should always be defined") } result, err := luacode.Arithmetic(op, rb, kc) if err != nil { - l.setTop(callerValueTop) return err } - registers[i.ArgA()] = importConstant(result) + *ra = importConstant(result) // The next instruction is a fallback metamethod invocation. nextPC++ } @@ -314,134 +554,245 @@ func (l *State) exec() error { luacode.OpBXOR, luacode.OpSHL, luacode.OpSHR: - if rb, isNumber := exportNumericConstant(registers[i.ArgB()]); isNumber { - if rc, isNumber := exportNumericConstant(registers[i.ArgC()]); isNumber { + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + rc, err := register(r, i.ArgC()) + if err != nil { + return err + } + if kb, isNumber := exportNumericConstant(*rb); isNumber { + if kc, isNumber := exportNumericConstant(*rc); isNumber { op, ok := i.OpCode().ArithmeticOperator() if !ok { panic("operator should always be defined") } - result, err := luacode.Arithmetic(op, rb, rc) + result, err := luacode.Arithmetic(op, kb, kc) if err != nil { - l.setTop(callerValueTop) return err } - registers[i.ArgA()] = importConstant(result) + *ra = importConstant(result) // The next instruction is a fallback metamethod invocation. nextPC++ } } case luacode.OpMMBin: - prev, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) + resultRegister, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) if err != nil { - l.setTop(callerValueTop) return err } - registers[prev.ArgA()], err = l.arithmeticMetamethod( - prevOperator.TagMethod(), - registers[i.ArgA()], - registers[i.ArgB()], - ) + r := registers() + if _, err := register(r, resultRegister); err != nil { + return err + } + ra, err := register(r, i.ArgA()) if err != nil { - l.setTop(callerValueTop) return err } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + + result, err := l.arithmeticMetamethod(prevOperator.TagMethod(), *ra, *rb) + if err != nil { + return err + } + // Calling a metamethod may grow the stack, + // so get register address afterward to avoid referencing an old array. + prevRA, err := register(registers(), resultRegister) + if err != nil { + return err + } + *prevRA = result case luacode.OpMMBinI: - prev, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) + resultRegister, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) + if err != nil { + return err + } + r := registers() + if _, err := register(r, resultRegister); err != nil { + return err + } + ra, err := register(r, i.ArgA()) if err != nil { - l.setTop(callerValueTop) return err } - registers[prev.ArgA()], err = l.arithmeticMetamethod( + result, err := l.arithmeticMetamethod( prevOperator.TagMethod(), - registers[i.ArgA()], + *ra, integerValue(luacode.SignedArg(i.ArgB())), ) if err != nil { - l.setTop(callerValueTop) return err } + // Calling a metamethod may grow the stack, + // so get register address afterward to avoid referencing an old array. + prevRA, err := register(registers(), resultRegister) + if err != nil { + return err + } + *prevRA = result case luacode.OpMMBinK: - prev, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) + resultRegister, prevOperator, err := decodeBinaryMetamethod(frame, f.proto) + if err != nil { + return err + } + r := registers() + if _, err := register(r, resultRegister); err != nil { + return err + } + ra, err := register(r, i.ArgA()) if err != nil { - l.setTop(callerValueTop) return err } - registers[prev.ArgA()], err = l.arithmeticMetamethod( + kb, err := constant(uint32(i.ArgB())) + if err != nil { + return err + } + result, err := l.arithmeticMetamethod( prevOperator.TagMethod(), - registers[i.ArgA()], - importConstant(f.proto.Constants[i.ArgB()]), + *ra, + importConstant(kb), ) if err != nil { - l.setTop(callerValueTop) return err } + // Calling a metamethod may grow the stack, + // so get register address afterward to avoid referencing an old array. + prevRA, err := register(registers(), resultRegister) + if err != nil { + return err + } + *prevRA = result case luacode.OpUNM: - rb := registers[i.ArgB()] - if ib, ok := rb.(integerValue); ok { - registers[i.ArgA()] = -ib - } else if nb, ok := toNumber(rb); ok { - registers[i.ArgA()] = -nb + r := registers() + a := i.ArgA() + ra, err := register(r, a) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + + if ib, ok := (*rb).(integerValue); ok { + *ra = -ib + } else if nb, ok := toNumber(*rb); ok { + *ra = -nb } else { - var err error - registers[i.ArgA()], err = l.arithmeticMetamethod(luacode.TagMethodUNM, rb, rb) + result, err := l.arithmeticMetamethod(luacode.TagMethodUNM, *rb, *rb) if err != nil { - l.setTop(callerValueTop) return err } + // Calling a metamethod may grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err = register(registers(), a) + if err != nil { + return err + } + *ra = result } case luacode.OpBNot: - rb := registers[i.ArgB()] - if ib, ok := rb.(integerValue); ok { - registers[i.ArgA()] = integerValue(^uint64(ib)) + r := registers() + a := i.ArgA() + ra, err := register(r, a) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + + if ib, ok := (*rb).(integerValue); ok { + kb := luacode.IntegerValue(int64(ib)) + result, err := luacode.Arithmetic(luacode.BitwiseNot, kb, luacode.Value{}) + if err != nil { + return err + } + *ra = importConstant(result) } else { - var err error - registers[i.ArgA()], err = l.arithmeticMetamethod(luacode.TagMethodBNot, rb, rb) + result, err := l.arithmeticMetamethod(luacode.TagMethodBNot, *rb, *rb) if err != nil { - l.setTop(callerValueTop) return err } + // Calling a metamethod may grow the stack, + // so get register address afterward to avoid referencing an old array. + ra, err = register(registers(), a) + if err != nil { + return err + } + *ra = result } case luacode.OpNot: - registers[i.ArgA()] = booleanValue(!toBoolean(registers[i.ArgB()])) + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + *ra = booleanValue(!toBoolean(*rb)) case luacode.OpJMP: nextPC += int(i.J()) case luacode.OpTest: - cond := toBoolean(registers[i.ArgA()]) + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + cond := toBoolean(*ra) if cond != i.K() { nextPC++ } case luacode.OpTestSet: - rb := registers[i.ArgB()] - cond := toBoolean(rb) + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + cond := toBoolean(*rb) if cond != i.K() { nextPC++ } else { - registers[i.ArgA()] = rb + *ra = *rb } case luacode.OpCall: numArguments := int(i.ArgB()) numResults := int(i.ArgC()) - 1 + // TODO(soon): Validate ArgA. l.setTop(frame.registerStart() + int(i.ArgA()) + 1 + numArguments) isLua, err := l.prepareCall(numArguments, numResults) if err != nil { - l.setTop(callerValueTop) return err } if isLua { - frame, f, registers, err = l.loadLuaFrame() + frame, f, err = l.loadLuaFrame() if err != nil { - l.setTop(callerValueTop) return err } } case luacode.OpReturn: + // TODO(soon): Validate ArgA+numResults. resultStackStart := frame.registerStart() + int(i.ArgA()) numResults := int(i.ArgB()) - 1 if numResults < 0 { numResults = len(l.stack) - resultStackStart } if i.K() { - l.setTop(callerValueTop) return errors.New("TODO(soon): close upvalues") } @@ -450,37 +801,62 @@ func (l *State) exec() error { if len(l.callStack) <= callerDepth { return nil } - frame, f, registers, err = l.loadLuaFrame() + frame, f, err = l.loadLuaFrame() if err != nil { - l.setTop(callerValueTop) return err } case luacode.OpReturn0: l.finishCall(0) - frame, f, registers, err = l.loadLuaFrame() + frame, f, err = l.loadLuaFrame() if err != nil { - l.setTop(callerValueTop) return err } case luacode.OpReturn1: + // TODO(soon): Validate ArgA. l.setTop(frame.registerStart() + int(i.ArgA()) + 1) l.finishCall(1) - frame, f, registers, err = l.loadLuaFrame() + frame, f, err = l.loadLuaFrame() if err != nil { - l.setTop(callerValueTop) return err } case luacode.OpSetList: - t := registers[i.ArgA()] - for idx := range i.ArgB() { - err := l.setIndex(t, integerValue(idx)+1, registers[i.ArgC()+idx+1]) + a := i.ArgA() + ra, err := register(registers(), a) + if err != nil { + return err + } + t, isTable := (*ra).(*table) + if !isTable { + return fmt.Errorf("set list (pc=%d): value in register %d is a %s (need table)", frame.pc, i.ArgA(), l.typeName(*ra)) + } + n := int(i.ArgB()) + stackBase := frame.registerStart() + int(a) + 1 + if n == 0 { + n = len(l.stack) - stackBase + } else if int(a)+1+n > int(f.proto.MaxStackSize) { + return fmt.Errorf("decode instruction (pc=%d): set list (a=%d n=%d) overflows stack (size=%d)", + frame.pc, a, n, f.proto.MaxStackSize) + } + indexBase := integerValue(i.ArgC()) + 1 + + for idx := range n { + // TODO(soon): We can do a much more efficient bulk insert here. + err := t.set(indexBase+integerValue(idx), l.stack[stackBase+idx]) if err != nil { - l.setTop(callerValueTop) return err } } case luacode.OpClosure: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + bx := i.ArgBx() + if int(bx) >= len(f.proto.Functions) { + return fmt.Errorf("decode instruction (pc=%d): closure %d out of range", frame.pc, bx) + } p := f.proto.Functions[i.ArgBx()] + upvalues := make([]upvalue, len(p.Upvalues)) for i, uv := range p.Upvalues { if uv.InStack { @@ -489,7 +865,7 @@ func (l *State) exec() error { upvalues[i] = f.upvalues[uv.Index] } } - registers[i.ArgA()] = luaFunction{ + *ra = luaFunction{ id: nextID(), proto: p, upvalues: upvalues, @@ -501,7 +877,6 @@ func (l *State) exec() error { } a := frame.registerStart() + int(i.ArgA()) if !l.grow(a + numWanted) { - l.setTop(callerValueTop) return errStackOverflow } l.setTop(a + numWanted) @@ -510,11 +885,9 @@ func (l *State) exec() error { clear(l.stack[a+n:]) case luacode.OpVarargPrep: if frame.pc != 0 { - l.setTop(callerValueTop) return fmt.Errorf("%v must be first instruction in function", luacode.OpVarargPrep) } if frame.numExtraArguments != 0 { - l.setTop(callerValueTop) return fmt.Errorf("cannot run %v multiple times", luacode.OpVarargPrep) } numArguments := len(l.stack) - frame.registerStart() @@ -524,18 +897,9 @@ func (l *State) exec() error { rotate(l.stack[frame.functionIndex:], numExtraArguments) frame.functionIndex += numExtraArguments frame.numExtraArguments = numExtraArguments - - // Reload frame to update register slice. - frame, f, registers, err = l.loadLuaFrame() - if err != nil { - l.setTop(callerValueTop) - return err - } } - case luacode.OpExtraArg: - return fmt.Errorf("unexpected %v at pc %d", luacode.OpExtraArg, frame.pc) default: - return fmt.Errorf("unhandled instruction %v", i.OpCode()) + return fmt.Errorf("decode instruction (pc=%d): unhandled instruction %v", frame.pc, i.OpCode()) } frame.pc = nextPC @@ -596,33 +960,36 @@ func (l *State) arithmeticMetamethod(event luacode.TagMethod, arg1, arg2 value) return nil, fmt.Errorf("attempt to perform %s on a %s value", kind, tname) } -func decodeBinaryMetamethod(frame *callFrame, proto *luacode.Prototype) (luacode.Instruction, luacode.ArithmeticOperator, error) { +func decodeBinaryMetamethod(frame *callFrame, proto *luacode.Prototype) (uint8, luacode.ArithmeticOperator, error) { i := proto.Code[frame.pc] if frame.pc <= 0 { - return 0, 0, fmt.Errorf("decode error: %v must be preceded by binary arithmetic instruction", i.OpCode()) + return 0, 0, fmt.Errorf("decode instruction (pc=%d): %v must be preceded by binary arithmetic instruction", frame.pc, i.OpCode()) } prev := proto.Code[frame.pc-1] prevOpCode := prev.OpCode() prevOperator, isArithmetic := prevOpCode.ArithmeticOperator() if !isArithmetic || !prevOperator.IsBinary() { - return 0, 0, fmt.Errorf("decode error: %v must be preceded by binary arithmetic instruction (found %v)", i.OpCode(), prevOpCode) + return 0, 0, fmt.Errorf("decode instruction (pc=%d): %v must be preceded by binary arithmetic instruction (found %v)", + frame.pc, i.OpCode(), prevOpCode) } if got, want := luacode.TagMethod(i.ArgC()), prevOperator.TagMethod(); got != want { - err := fmt.Errorf("decode error: found metamethod %v in %v after %v (expected %v)", - got, i.OpCode(), prev.OpCode(), want) - return prev, prevOperator, err + err := fmt.Errorf("decode instruction (pc=%d): found metamethod %v in %v after %v (expected %v)", + frame.pc, got, i.OpCode(), prev.OpCode(), want) + return prev.ArgA(), prevOperator, err } - return prev, prevOperator, nil + return prev.ArgA(), prevOperator, nil } func decodeExtraArg(frame *callFrame, proto *luacode.Prototype) (uint32, error) { argPC := frame.pc + 1 if argPC >= len(proto.Code) { - return 0, fmt.Errorf("%v (last instruction) expects extra argument", proto.Code[frame.pc].OpCode()) + return 0, fmt.Errorf("decode instruction (pc=%d, last): %v expects extra argument", + frame.pc, proto.Code[frame.pc].OpCode()) } i := proto.Code[argPC] if got := i.OpCode(); got != luacode.OpExtraArg { - return 0, fmt.Errorf("%v expects extra argument (found %v)", proto.Code[frame.pc].OpCode(), got) + return 0, fmt.Errorf("decode instruction (pc=%d): %v expects extra argument (found %v)", + frame.pc, proto.Code[frame.pc].OpCode(), got) } return i.ArgAx(), nil } diff --git a/internal/mylua/vm_test.go b/internal/mylua/vm_test.go index 799572c..ec108de 100644 --- a/internal/mylua/vm_test.go +++ b/internal/mylua/vm_test.go @@ -6,6 +6,8 @@ package mylua import ( "strings" "testing" + + "zb.256lights.llc/pkg/internal/luacode" ) func TestVM(t *testing.T) { @@ -92,4 +94,76 @@ func TestVM(t *testing.T) { t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, want) } }) + + t.Run("SetListSmall", func(t *testing.T) { + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + const source = `return {"abc", 42, 3.14}` + if err := state.Load(strings.NewReader(source), source, "t"); err != nil { + t.Fatal(err) + } + if err := state.Call(0, 1, 0); err != nil { + t.Fatal(err) + } + if !state.IsTable(-1) { + t.Fatalf("top of stack is %v; want table", state.Type(-1)) + } + if got, want := state.RawLen(-1), 3; got != uint64(want) { + t.Errorf("table size = %d; want %d", got, want) + } + + state.RawIndex(-1, 1) + if got, ok := state.ToString(-1); got != "abc" || !ok { + t.Errorf("table[1] = %q, %t; want %q, true", got, ok, "abc") + } + state.Pop(1) + + state.RawIndex(-1, 2) + if got, ok := state.ToInteger(-1); got != 42 || !ok { + t.Errorf("table[2] = %d, %t; want %d, true", got, ok, 42) + } + state.Pop(1) + + state.RawIndex(-1, 3) + if got, ok := state.ToNumber(-1); got != 3.14 || !ok { + t.Errorf("table[3] = %g, %t; want %g, true", got, ok, 3.14) + } + state.Pop(1) + }) + + t.Run("SetListLarge", func(t *testing.T) { + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + const wantLength = 256 + source := "return {42" + strings.Repeat(",42", wantLength-1) + "}" + if err := state.Load(strings.NewReader(source), luacode.Source(source), "t"); err != nil { + t.Fatal(err) + } + if err := state.Call(0, 1, 0); err != nil { + t.Fatal(err) + } + if !state.IsTable(-1) { + t.Fatalf("top of stack is %v; want table", state.Type(-1)) + } + if got := state.RawLen(-1); got != uint64(wantLength) { + t.Errorf("table size = %d; want %d", got, wantLength) + } + for i := int64(1); i <= wantLength; i++ { + state.RawIndex(-1, i) + if got, ok := state.ToInteger(-1); got != 42 || !ok { + t.Errorf("table[%d] = %d, %t; want %d, true", i, got, ok, 42) + } + state.Pop(1) + } + }) }