diff --git a/internal/mylua/lua.go b/internal/mylua/lua.go index 4d9700c..efa9aad 100644 --- a/internal/mylua/lua.go +++ b/internal/mylua/lua.go @@ -1,6 +1,8 @@ // Copyright 2024 The zb Authors // SPDX-License-Identifier: MIT +//go:generate stringer -type=ComparisonOperator -linecomment -output=lua_string.go + package mylua import ( @@ -554,6 +556,124 @@ func (l *State) RawLen(idx int) uint64 { return uint64(lv.len()) } +// RawEqual reports whether the two values in the given indices +// are primitively equal (that is, equal without calling the __eq metamethod). +// If either index is invalid, then RawEqual reports false. +func (l *State) RawEqual(idx1, idx2 int) bool { + l.init() + v1, _, err := l.valueByIndex(idx1) + if err != nil { + return false + } + v2, _, err := l.valueByIndex(idx2) + if err != nil { + return false + } + return valuesEqual(v1, v2) +} + +// ComparisonOperator is an enumeration of operators +// that can be used with [*State.Compare]. +type ComparisonOperator int + +// Defined [ComparisonOperator] values. +const ( + Equal ComparisonOperator = iota // == + Less // < + LessOrEqual // <= +) + +func (l *State) Compare(idx1, idx2 int, op ComparisonOperator, msgHandler int) (bool, error) { + l.init() + v1, _, err := l.valueByIndex(idx1) + if err != nil { + return false, err + } + v2, _, err := l.valueByIndex(idx2) + if err != nil { + return false, err + } + if msgHandler != 0 { + return false, fmt.Errorf("TODO(someday): support message handlers") + } + return l.compare(op, v1, v2) +} + +// compare returns the result of comparing v1 and v2 with the given operator +// according to Lua's full comparison rules (including metamethods). +func (l *State) compare(op ComparisonOperator, v1, v2 value) (bool, error) { + switch op { + case Equal: + return l.equal(v1, v2) + case Less, LessOrEqual: + t1, t2 := valueType(v1), valueType(v2) + if t1 == TypeNumber && t2 == TypeNumber || t1 == TypeString && t2 == TypeString { + result := compareValues(v1, v2) + return result < 0 || result == 0 && op == LessOrEqual, nil + } + var eventName stringValue + switch op { + case Less: + eventName = stringValue{s: luacode.TagMethodLT.String()} + case LessOrEqual: + eventName = stringValue{s: luacode.TagMethodLE.String()} + default: + panic("unreachable") + } + f := l.metatable(v1).get(eventName) + if f == nil { + f = l.metatable(v2).get(eventName) + if f == nil { + // Neither value has the needed metamethod. + tn1 := l.typeName(v1) + tn2 := l.typeName(v2) + if tn1 == tn2 { + return false, fmt.Errorf("attempt to compare two %s values", tn1) + } + return false, fmt.Errorf("attempt to compare %s with %s", tn1, tn2) + } + } + result, err := l.call1(f, v1, v2) + if err != nil { + return false, err + } + return toBoolean(result), nil + default: + return false, fmt.Errorf("invalid %v", op) + } +} + +// equal reports whether v1 == v2 according to Lua's full equality rules (including metamethods). +func (l *State) equal(v1, v2 value) (bool, error) { + // Values of different types are never equal. + t1, t2 := valueType(v1), valueType(v2) + if t1 != t2 { + return false, nil + } + // If the values are primitively equal, then it's equal. + if valuesEqual(v1, v2) { + return true, nil + } + // Check __eq metamethod for types with individual metatables. + if !(t1 == TypeTable || t1 == TypeUserdata) { + return false, nil + } + eventName := stringValue{s: luacode.TagMethodEQ.String()} + f := l.metatable(v1).get(eventName) + if f == nil { + f = l.metatable(v2).get(eventName) + if f == nil { + // Neither value has an __eq metamethod. + return false, nil + } + } + result, err := l.call1(f, v1, v2) + if err != nil { + return false, err + } + return toBoolean(result), nil +} + func (l *State) push(x value) { if len(l.stack) == cap(l.stack) { panic(errStackOverflow) diff --git a/internal/mylua/lua_string.go b/internal/mylua/lua_string.go new file mode 100644 index 0000000..e2256d4 --- /dev/null +++ b/internal/mylua/lua_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=ComparisonOperator -linecomment -output=lua_string.go"; DO NOT EDIT. + +package mylua + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Equal-0] + _ = x[Less-1] + _ = x[LessOrEqual-2] +} + +const _ComparisonOperator_name = "==<<=" + +var _ComparisonOperator_index = [...]uint8{0, 2, 3, 5} + +func (i ComparisonOperator) String() string { + if i < 0 || i >= ComparisonOperator(len(_ComparisonOperator_index)-1) { + return "ComparisonOperator(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ComparisonOperator_name[_ComparisonOperator_index[i]:_ComparisonOperator_index[i+1]] +} diff --git a/internal/mylua/lua_test.go b/internal/mylua/lua_test.go index 663fa58..df7a708 100644 --- a/internal/mylua/lua_test.go +++ b/internal/mylua/lua_test.go @@ -8,12 +8,14 @@ import ( "errors" "io" "slices" + "strconv" "strings" "testing" "testing/iotest" "github.com/google/go-cmp/cmp" "zb.256lights.llc/pkg/internal/luacode" + "zb.256lights.llc/pkg/internal/lualex" ) func TestLoad(t *testing.T) { @@ -96,6 +98,139 @@ func TestLoad(t *testing.T) { }) } +func TestCompare(t *testing.T) { + type compareTable [3]int8 + const bad int8 = -1 + + tests := []struct { + name string + push func(l *State) + want compareTable + }{ + { + name: "StringNumber", + push: func(l *State) { + l.PushString("0") + l.PushInteger(0) + }, + want: compareTable{ + Equal: 0, + Less: bad, + LessOrEqual: bad, + }, + }, + { + name: "NumberString", + push: func(l *State) { + l.PushInteger(0) + l.PushString("0") + }, + want: compareTable{ + Equal: 0, + Less: bad, + LessOrEqual: bad, + }, + }, + } + + t.Run("StateMethod", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + test.push(state) + s1 := describeValue(state, -2) + s2 := describeValue(state, -1) + + for opIndex, want := range test.want { + op := ComparisonOperator(opIndex) + got, err := state.Compare(-2, -1, op, 0) + if got != (want == 1) || (err != nil) != (want == bad) { + wantError := "" + if want == bad { + wantError = "" + } + t.Errorf("(%s %v %s) = %t, %v; want %t, %s", + s1, op, s2, got, err, (want == 1), wantError) + } + } + }) + } + }) + + t.Run("Load", func(t *testing.T) { + // Parse scripts for comparing two arguments. + scripts := [len(compareTable{})][]byte{} + for i := range scripts { + op := ComparisonOperator(i) + source := "local x, y = ...\nreturn x " + op.String() + " y\n" + proto, err := luacode.Parse(luacode.Source(source), strings.NewReader(source)) + if err != nil { + t.Fatal(err) + } + scripts[i], err = proto.MarshalBinary() + if err != nil { + t.Fatal(err) + } + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + state := new(State) + defer func() { + if err := state.Close(); err != nil { + t.Error("Close:", err) + } + }() + + test.push(state) + s1 := describeValue(state, -2) + s2 := describeValue(state, -1) + + for i, want := range test.want { + op := ComparisonOperator(i) + // TODO(now) + if err := state.Load(bytes.NewReader(scripts[i]), "", "b"); err != nil { + t.Error("Load:", err) + continue + } + + // Copy pushed values on top of function pushed. + state.PushValue(-3) + state.PushValue(-3) + + if err := state.Call(2, 1, 0); err != nil { + t.Logf("(%s %v %s): %v", s1, op, s2, err) + if want != bad { + t.Fail() + } + continue + } + if want == bad { + t.Fatalf("Comparison did not throw an error") + } + + if got, want := state.Type(-1), TypeBoolean; got != want { + t.Errorf("(%s %v %s) returned %v; want %v", + s1, op, s2, got, want) + } + got := state.ToBoolean(-1) + if got != (want == 1) { + t.Errorf("(%s %v %s) = %t, ; want %t, ", + s1, op, s2, got, (want == 1)) + } + state.Pop(1) + } + }) + } + }) +} + func TestRotate(t *testing.T) { tests := []struct { s []int @@ -138,3 +273,37 @@ func BenchmarkExec(b *testing.B) { state.Pop(1) } } + +func describeValue(l *State, idx int) string { + switch l.Type(idx) { + case TypeNone: + return "" + case TypeNil: + return "nil" + case TypeBoolean: + return strconv.FormatBool(l.ToBoolean(idx)) + case TypeString: + s, _ := l.ToString(idx) + return lualex.Quote(s) + case TypeNumber: + if l.IsInteger(idx) { + i, _ := l.ToInteger(idx) + return strconv.FormatInt(i, 10) + } + f, _ := l.ToNumber(idx) + return strconv.FormatFloat(f, 'g', 0, 64) + case TypeTable: + if l.RawLen(idx) == 0 { + return "{}" + } + return "{...}" + case TypeFunction: + return "" + case TypeLightUserdata, TypeUserdata: + return "" + case TypeThread: + return "" + default: + return "" + } +} diff --git a/internal/mylua/value.go b/internal/mylua/value.go index fbd273f..4c659ec 100644 --- a/internal/mylua/value.go +++ b/internal/mylua/value.go @@ -122,6 +122,13 @@ func exportNumericConstant(v value) (_ luacode.Value, ok bool) { // For [floatValue], a NaN is considered less than any non-NaN, // a NaN is considered equal to a NaN, // and -0.0 is equal to 0.0. +// +// This is a superset of the comparisons performed by [Lua relational operators] +// for the purpose of providing a total ordering for tables. +// +// If you only need to check for equality, [valuesEqual] is more efficient. +// +// [Lua relational operators]: https://www.lua.org/manual/5.4/manual.html#3.4.4 func compareValues(v1, v2 value) int { switch v1 := v1.(type) { case nil: @@ -178,6 +185,48 @@ func compareValues(v1, v2 value) int { } } +// valuesEqual reports whether v1 and v2 are [primitively equal] — +// that is, whether they are equal in Lua without consulting the “__eq” metamethod. +// This involves less comparisons than [compareValues]. +// +// [primitively equal]: https://www.lua.org/manual/5.4/manual.html#3.4.4 +func valuesEqual(v1, v2 value) bool { + switch v1 := v1.(type) { + case nil: + return v2 == nil + case booleanValue: + b2, ok := v2.(booleanValue) + return ok && v1 == b2 + case floatValue: + switch v2 := v2.(type) { + case integerValue: + i1, ok := v1.toInteger() + return ok && i1 == v2 + case floatValue: + return v1 == v2 + default: + return false + } + case integerValue: + switch v2 := v2.(type) { + case integerValue: + return v1 == v2 + case floatValue: + i2, ok := v2.toInteger() + return ok && v1 == i2 + default: + return false + } + case stringValue: + s2, ok := v2.(stringValue) + return ok && v1.s == s2.s + case *table, function: + return v1 == v2 + default: + panic("unhandled type") + } +} + // numericValue is an optional interface for types that implement [value] // and can be [coerced] to a number. // diff --git a/internal/mylua/vm.go b/internal/mylua/vm.go index af9a4eb..e670eb3 100644 --- a/internal/mylua/vm.go +++ b/internal/mylua/vm.go @@ -157,7 +157,7 @@ func (l *State) exec() (err error) { } nextPC := frame.pc + 1 - switch i.OpCode() { + switch opCode := i.OpCode(); opCode { case luacode.OpMove: r := registers() ra, err := register(r, i.ArgA()) @@ -470,7 +470,7 @@ func (l *State) exec() (err error) { } c := luacode.IntegerValue(int64(luacode.SignedArg(i.ArgC()))) if kb, isNumber := exportNumericConstant(*rb); isNumber { - op, ok := i.OpCode().ArithmeticOperator() + op, ok := opCode.ArithmeticOperator() if !ok { panic("operator should always be defined") } @@ -527,10 +527,10 @@ func (l *State) exec() (err error) { return err } if !kc.IsNumber() { - return fmt.Errorf("decode instruction (pc=%d): %v on non-numeric constant %v", frame.pc, i.OpCode(), kc) + return fmt.Errorf("decode instruction (pc=%d): %v on non-numeric constant %v", frame.pc, opCode, kc) } if rb, isNumber := exportNumericConstant(*rb); isNumber { - op, ok := i.OpCode().ArithmeticOperator() + op, ok := opCode.ArithmeticOperator() if !ok { panic("operator should always be defined") } @@ -569,7 +569,7 @@ func (l *State) exec() (err error) { } if kb, isNumber := exportNumericConstant(*rb); isNumber { if kc, isNumber := exportNumericConstant(*rc); isNumber { - op, ok := i.OpCode().ArithmeticOperator() + op, ok := opCode.ArithmeticOperator() if !ok { panic("operator should always be defined") } @@ -776,6 +776,107 @@ func (l *State) exec() (err error) { } case luacode.OpJMP: nextPC += int(i.J()) + case luacode.OpEQ: + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + result, err := l.equal(*ra, *rb) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } + case luacode.OpEQK: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + kb, err := constant(uint32(i.ArgB())) + if err != nil { + return err + } + result, err := l.equal(*ra, importConstant(kb)) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } + case luacode.OpEQI: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + result, err := l.equal(*ra, integerValue(luacode.SignedArg(i.ArgB()))) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } + case luacode.OpLT, luacode.OpLE: + r := registers() + ra, err := register(r, i.ArgA()) + if err != nil { + return err + } + rb, err := register(r, i.ArgB()) + if err != nil { + return err + } + op := Less + if opCode == luacode.OpLE { + op = LessOrEqual + } + result, err := l.compare(op, *ra, *rb) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } + case luacode.OpLTI, luacode.OpLEI: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + op := Less + if opCode == luacode.OpLEI { + op = LessOrEqual + } + result, err := l.compare(op, *ra, integerValue(luacode.SignedArg(i.ArgB()))) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } + case luacode.OpGTI, luacode.OpGEI: + ra, err := register(registers(), i.ArgA()) + if err != nil { + return err + } + // According to the Lua reference manual, + // "A comparison a > b is translated to b < a and a >= b is translated to b <= a." + // https://www.lua.org/manual/5.4/manual.html#3.4.4 + op := Less + if opCode == luacode.OpGEI { + op = LessOrEqual + } + result, err := l.compare(op, integerValue(luacode.SignedArg(i.ArgB())), *ra) + if err != nil { + return err + } + if result != i.K() { + nextPC++ + } case luacode.OpTest: ra, err := register(registers(), i.ArgA()) if err != nil { @@ -930,7 +1031,7 @@ func (l *State) exec() (err error) { frame.numExtraArguments = numExtraArguments } default: - return fmt.Errorf("decode instruction (pc=%d): unhandled instruction %v", frame.pc, i.OpCode()) + return fmt.Errorf("decode instruction (pc=%d): unhandled instruction %v", frame.pc, opCode) } frame.pc = nextPC