Skip to content

Commit

Permalink
Implement relational operators
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Dec 14, 2024
1 parent c2476dc commit 808a2f3
Show file tree
Hide file tree
Showing 5 changed files with 470 additions and 6 deletions.
120 changes: 120 additions & 0 deletions internal/mylua/lua.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright 2024 The zb Authors
// SPDX-License-Identifier: MIT

//go:generate stringer -type=ComparisonOperator -linecomment -output=lua_string.go

package mylua

import (
Expand Down Expand Up @@ -554,6 +556,124 @@ func (l *State) RawLen(idx int) uint64 {
return uint64(lv.len())
}

// RawEqual reports whether the two values in the given indices
// are primitively equal (that is, equal without calling the __eq metamethod).
// If either index is invalid, then RawEqual reports false.
func (l *State) RawEqual(idx1, idx2 int) bool {
l.init()
v1, _, err := l.valueByIndex(idx1)
if err != nil {
return false
}
v2, _, err := l.valueByIndex(idx2)
if err != nil {
return false
}
return valuesEqual(v1, v2)
}

// ComparisonOperator is an enumeration of operators
// that can be used with [*State.Compare].
type ComparisonOperator int

// Defined [ComparisonOperator] values.
const (
Equal ComparisonOperator = iota // ==
Less // <
LessOrEqual // <=
)

func (l *State) Compare(idx1, idx2 int, op ComparisonOperator, msgHandler int) (bool, error) {
l.init()
v1, _, err := l.valueByIndex(idx1)
if err != nil {
return false, err
}
v2, _, err := l.valueByIndex(idx2)
if err != nil {
return false, err
}
if msgHandler != 0 {
return false, fmt.Errorf("TODO(someday): support message handlers")
}
return l.compare(op, v1, v2)
}

// compare returns the result of comparing v1 and v2 with the given operator
// according to Lua's full comparison rules (including metamethods).
func (l *State) compare(op ComparisonOperator, v1, v2 value) (bool, error) {
switch op {
case Equal:
return l.equal(v1, v2)
case Less, LessOrEqual:
t1, t2 := valueType(v1), valueType(v2)
if t1 == TypeNumber && t2 == TypeNumber || t1 == TypeString && t2 == TypeString {
result := compareValues(v1, v2)
return result < 0 || result == 0 && op == LessOrEqual, nil
}
var eventName stringValue
switch op {
case Less:
eventName = stringValue{s: luacode.TagMethodLT.String()}
case LessOrEqual:
eventName = stringValue{s: luacode.TagMethodLE.String()}
default:
panic("unreachable")
}
f := l.metatable(v1).get(eventName)
if f == nil {
f = l.metatable(v2).get(eventName)
if f == nil {
// Neither value has the needed metamethod.
tn1 := l.typeName(v1)
tn2 := l.typeName(v2)
if tn1 == tn2 {
return false, fmt.Errorf("attempt to compare two %s values", tn1)
}
return false, fmt.Errorf("attempt to compare %s with %s", tn1, tn2)
}
}
result, err := l.call1(f, v1, v2)
if err != nil {
return false, err
}
return toBoolean(result), nil
default:
return false, fmt.Errorf("invalid %v", op)
}
}

// equal reports whether v1 == v2 according to Lua's full equality rules (including metamethods).
func (l *State) equal(v1, v2 value) (bool, error) {
// Values of different types are never equal.
t1, t2 := valueType(v1), valueType(v2)
if t1 != t2 {
return false, nil
}
// If the values are primitively equal, then it's equal.
if valuesEqual(v1, v2) {
return true, nil
}
// Check __eq metamethod for types with individual metatables.
if !(t1 == TypeTable || t1 == TypeUserdata) {
return false, nil
}
eventName := stringValue{s: luacode.TagMethodEQ.String()}
f := l.metatable(v1).get(eventName)
if f == nil {
f = l.metatable(v2).get(eventName)
if f == nil {
// Neither value has an __eq metamethod.
return false, nil
}
}
result, err := l.call1(f, v1, v2)
if err != nil {
return false, err
}
return toBoolean(result), nil
}

func (l *State) push(x value) {
if len(l.stack) == cap(l.stack) {
panic(errStackOverflow)
Expand Down
25 changes: 25 additions & 0 deletions internal/mylua/lua_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

169 changes: 169 additions & 0 deletions internal/mylua/lua_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"errors"
"io"
"slices"
"strconv"
"strings"
"testing"
"testing/iotest"

"github.com/google/go-cmp/cmp"
"zb.256lights.llc/pkg/internal/luacode"
"zb.256lights.llc/pkg/internal/lualex"
)

func TestLoad(t *testing.T) {
Expand Down Expand Up @@ -96,6 +98,139 @@ func TestLoad(t *testing.T) {
})
}

func TestCompare(t *testing.T) {
type compareTable [3]int8
const bad int8 = -1

tests := []struct {
name string
push func(l *State)
want compareTable
}{
{
name: "StringNumber",
push: func(l *State) {
l.PushString("0")
l.PushInteger(0)
},
want: compareTable{
Equal: 0,
Less: bad,
LessOrEqual: bad,
},
},
{
name: "NumberString",
push: func(l *State) {
l.PushInteger(0)
l.PushString("0")
},
want: compareTable{
Equal: 0,
Less: bad,
LessOrEqual: bad,
},
},
}

t.Run("StateMethod", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

test.push(state)
s1 := describeValue(state, -2)
s2 := describeValue(state, -1)

for opIndex, want := range test.want {
op := ComparisonOperator(opIndex)
got, err := state.Compare(-2, -1, op, 0)
if got != (want == 1) || (err != nil) != (want == bad) {
wantError := "<nil>"
if want == bad {
wantError = "<error>"
}
t.Errorf("(%s %v %s) = %t, %v; want %t, %s",
s1, op, s2, got, err, (want == 1), wantError)
}
}
})
}
})

t.Run("Load", func(t *testing.T) {
// Parse scripts for comparing two arguments.
scripts := [len(compareTable{})][]byte{}
for i := range scripts {
op := ComparisonOperator(i)
source := "local x, y = ...\nreturn x " + op.String() + " y\n"
proto, err := luacode.Parse(luacode.Source(source), strings.NewReader(source))
if err != nil {
t.Fatal(err)
}
scripts[i], err = proto.MarshalBinary()
if err != nil {
t.Fatal(err)
}
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

test.push(state)
s1 := describeValue(state, -2)
s2 := describeValue(state, -1)

for i, want := range test.want {
op := ComparisonOperator(i)
// TODO(now)
if err := state.Load(bytes.NewReader(scripts[i]), "", "b"); err != nil {
t.Error("Load:", err)
continue
}

// Copy pushed values on top of function pushed.
state.PushValue(-3)
state.PushValue(-3)

if err := state.Call(2, 1, 0); err != nil {
t.Logf("(%s %v %s): %v", s1, op, s2, err)
if want != bad {
t.Fail()
}
continue
}
if want == bad {
t.Fatalf("Comparison did not throw an error")
}

if got, want := state.Type(-1), TypeBoolean; got != want {
t.Errorf("(%s %v %s) returned %v; want %v",
s1, op, s2, got, want)
}
got := state.ToBoolean(-1)
if got != (want == 1) {
t.Errorf("(%s %v %s) = %t, <nil>; want %t, <nil>",
s1, op, s2, got, (want == 1))
}
state.Pop(1)
}
})
}
})
}

func TestRotate(t *testing.T) {
tests := []struct {
s []int
Expand Down Expand Up @@ -138,3 +273,37 @@ func BenchmarkExec(b *testing.B) {
state.Pop(1)
}
}

func describeValue(l *State, idx int) string {
switch l.Type(idx) {
case TypeNone:
return "<none>"
case TypeNil:
return "nil"
case TypeBoolean:
return strconv.FormatBool(l.ToBoolean(idx))
case TypeString:
s, _ := l.ToString(idx)
return lualex.Quote(s)
case TypeNumber:
if l.IsInteger(idx) {
i, _ := l.ToInteger(idx)
return strconv.FormatInt(i, 10)
}
f, _ := l.ToNumber(idx)
return strconv.FormatFloat(f, 'g', 0, 64)
case TypeTable:
if l.RawLen(idx) == 0 {
return "{}"
}
return "{...}"
case TypeFunction:
return "<function>"
case TypeLightUserdata, TypeUserdata:
return "<userdata>"
case TypeThread:
return "<thread>"
default:
return "<unknown>"
}
}
Loading

0 comments on commit 808a2f3

Please sign in to comment.