Skip to content

Commit

Permalink
Implement string concatenation
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Dec 13, 2024
1 parent a99b969 commit 2795619
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 5 deletions.
158 changes: 153 additions & 5 deletions internal/mylua/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"slices"
"strings"

"zb.256lights.llc/pkg/internal/luacode"
"zb.256lights.llc/pkg/sets"
Expand Down Expand Up @@ -506,11 +507,8 @@ func (l *State) ToString(idx int) (s string, ok bool) {
switch v := (*p).(type) {
case stringValue:
return v.s, true
case integerValue:
s, _ := luacode.IntegerValue(int64(v)).Unquoted()
return s, true
case floatValue:
s, _ := luacode.FloatValue(float64(v)).Unquoted()
case valueStringer:
*p = v.stringValue()
return s, true
default:
return "", false
Expand Down Expand Up @@ -1243,6 +1241,156 @@ func (l *State) Load(r io.Reader, chunkName luacode.Source, mode string) (err er
return nil
}

// Concat replaces the n values at the top of the stack
// with their concatenation, following the usual semantics of Lua.
// If n is 1, the single value on the stack is already the result
// (that is, the function does nothing);
// if n is 0, the empty string is pushed.
//
// msgHandler must currently be 0: message handlers are not supported yet.
//
// Argument validation failures
// (a negative n, an n larger than the stack, or a non-zero msgHandler)
// return an error without touching the stack.
// If the concatenation itself fails,
// Concat pushes nil onto the stack and returns the error.
func (l *State) Concat(n, msgHandler int) error {
	switch {
	case n < 0:
		return errors.New("lua concat: negative argument length")
	case n > l.Top():
		return errors.New("lua concat: not enough arguments on the stack")
	case msgHandler != 0:
		return fmt.Errorf("TODO(someday): support message handlers")
	}

	err := l.concat(n)
	if err != nil {
		// Leave a placeholder error object for the caller.
		l.push(nil)
	}
	return err
}

// concat concatenates the n values at the top of the stack,
// leaving a single result in their place.
// If n is 0, it pushes the empty string.
// If n is 1, the stack is left unchanged.
//
// Operands are combined starting from the top of the stack.
// If either of the top two values cannot be coerced to a string,
// the __concat metamethod is tried via concatMetamethod.
// If that fails, the top of the stack is reset to the position
// of the first operand and the error is returned.
//
// Unlike [State.Concat], concat does not reject a negative n
// and does not push nil on failure.
func (l *State) concat(n int) error {
	if n == 0 {
		l.push(stringValue{})
		return nil
	}
	firstArg := len(l.stack) - n
	if firstArg < l.frame().registerStart() {
		return errors.New("concat: stack underflow")
	}

	isEmptyString := func(v value) bool {
		sv, ok := v.(stringValue)
		return ok && sv.isEmpty()
	}

	for len(l.stack) > firstArg+1 {
		v1 := l.stack[len(l.stack)-2]
		vs1, isStringer1 := v1.(valueStringer)
		v2 := l.stack[len(l.stack)-1]
		vs2, isStringer2 := v2.(valueStringer)
		switch {
		case !isStringer1 || !isStringer2:
			if err := l.concatMetamethod(); err != nil {
				l.setTop(firstArg)
				return err
			}
		case isEmptyString(v1):
			// The result is the second operand, coerced to a string.
			l.stack[len(l.stack)-2] = vs2.stringValue()
			l.setTop(len(l.stack) - 1)
		case isEmptyString(v2):
			// The result is the first operand.
			// It must still be coerced: concatenating a number
			// with an empty string yields a string in Lua, not a number.
			l.stack[len(l.stack)-2] = vs1.stringValue()
			l.setTop(len(l.stack) - 1)
		default:
			// The end of the slice has two or more non-empty strings.
			// Find the longest run of values that can be coerced to a string,
			// and perform raw string concatenation.
			concatStart := firstArg + stringerTailStart(l.stack[firstArg:len(l.stack)-2])
			initialCapacity, hasContext := minConcatSize(l.stack[concatStart:])
			sb := new(strings.Builder)
			sb.Grow(initialCapacity)
			// Only allocate a context set if some operand has context.
			var sctx sets.Set[string]
			if hasContext {
				sctx = make(sets.Set[string])
			}

			for _, v := range l.stack[concatStart:] {
				sv := v.(valueStringer).stringValue()
				sb.WriteString(sv.s)
				sctx.AddSeq(sv.context.All())
			}

			// Collapse the whole run into a single string value.
			l.stack[concatStart] = stringValue{
				s:       sb.String(),
				context: sctx,
			}
			l.setTop(concatStart + 1)
		}
	}
	return nil
}

// concatMetamethod invokes the __concat metamethod
// on the two values at the top of the stack.
// The metatable of the first operand is consulted before that of the second.
// If neither provides the metamethod,
// the returned error names the type of the operand
// that could not be concatenated.
func (l *State) concatMetamethod() error {
	top := len(l.stack)
	lhs, rhs := l.stack[top-2], l.stack[top-1]
	event := stringValue{s: luacode.TagMethodConcat.String()}

	handler := l.metatable(lhs).get(event)
	if handler == nil {
		handler = l.metatable(rhs).get(event)
	}
	if handler == nil {
		// Blame the first operand unless it is string-coercible,
		// in which case the second operand must be the offender.
		culprit := lhs
		if _, coercible := culprit.(valueStringer); coercible {
			culprit = rhs
		}
		return fmt.Errorf("attempt to concatenate a %s value", l.typeName(culprit))
	}

	// Place the metamethod underneath its two arguments.
	l.push(handler)
	rotate(l.stack[len(l.stack)-3:], 1)

	// Call the metamethod with two arguments, expecting one result.
	// NOTE(review): exec is only run when prepareCall reports a Lua function;
	// presumably prepareCall completes other calls itself — confirm.
	isLua, err := l.prepareCall(2, 1)
	if err != nil {
		return err
	}
	if !isLua {
		return nil
	}
	return l.exec()
}

// stringerTailStart reports the smallest index i
// such that every element of values[i:] implements [valueStringer].
// It returns len(values) if the last element does not implement it
// (or if values is empty).
func stringerTailStart(values []value) int {
	i := len(values)
	for i > 0 {
		if _, ok := values[i-1].(valueStringer); !ok {
			break
		}
		i--
	}
	return i
}

// minConcatSize computes a lower bound on the buffer size
// needed to concatenate the given values.
// String values contribute their exact byte length;
// every other value is counted as a single byte.
// hasContext reports whether any string value carries a non-empty context.
func minConcatSize(values []value) (n int, hasContext bool) {
	for _, v := range values {
		sv, isString := v.(stringValue)
		if !isString {
			// Numbers are non-empty, so add 1.
			n++
			continue
		}
		n += len(sv.s)
		if len(sv.context) > 0 {
			hasContext = true
		}
	}
	return n, hasContext
}

// Len pushes the length of the value at the given index to the stack.
// It is equivalent to the ['#' operator in Lua]
// and may trigger a [metamethod] for the "length" event.
Expand Down
36 changes: 36 additions & 0 deletions internal/mylua/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,24 @@ func toBoolean(v value) bool {
}
}

// valueStringer is implemented by values
// that can be converted to a [stringValue].
type valueStringer interface {
// stringValue returns the receiver converted to a stringValue.
stringValue() stringValue
}

// Compile-time checks that the number and string value types
// implement valueStringer.
var (
_ valueStringer = floatValue(0)
_ valueStringer = integerValue(0)
_ valueStringer = stringValue{}
)

// toString converts v to a stringValue.
// ok is false (and the result is the zero stringValue)
// if v does not implement [valueStringer].
func toString(v value) (_ stringValue, ok bool) {
	if stringer, isStringer := v.(valueStringer); isStringer {
		return stringer.stringValue(), true
	}
	return stringValue{}, false
}

// lenValue is a [value] that has a defined "raw" length.
type lenValue interface {
value
Expand Down Expand Up @@ -373,6 +391,11 @@ func (v integerValue) valueType() Type { return TypeNumber }
// toNumber converts v to a floatValue. The conversion always succeeds.
func (v integerValue) toNumber() (floatValue, bool) { return floatValue(v), true }
// toInteger returns v unchanged. It always succeeds.
func (v integerValue) toInteger() (integerValue, bool) { return v, true }

// stringValue converts v to a stringValue
// using the unquoted form of the equivalent [luacode.IntegerValue].
// The resulting string carries no context.
func (v integerValue) stringValue() stringValue {
	// NOTE(review): the second result of Unquoted is deliberately discarded;
	// presumably it cannot report failure for integers — confirm.
	text, _ := luacode.IntegerValue(int64(v)).Unquoted()
	return stringValue{s: text}
}

// floatValue is a floating-point [value]
// whose underlying type is float64.
type floatValue float64

Expand All @@ -384,6 +407,11 @@ func (v floatValue) toInteger() (integerValue, bool) {
return integerValue(i), ok
}

// stringValue converts v to a stringValue
// using the unquoted form of the equivalent [luacode.FloatValue].
// The resulting string carries no context.
func (v floatValue) stringValue() stringValue {
	// NOTE(review): the second result of Unquoted is deliberately discarded;
	// presumably it cannot report failure for floats — confirm.
	text, _ := luacode.FloatValue(float64(v)).Unquoted()
	return stringValue{s: text}
}

// stringValue is a string [value].
// stringValues implement [numericValue] because they can be coerced to numbers.
//
Expand All @@ -403,6 +431,14 @@ func (v stringValue) len() integerValue {
return integerValue(len(v.s))
}

// isEmpty reports whether v has neither string content nor context.
func (v stringValue) isEmpty() bool {
	if v.s != "" {
		return false
	}
	return len(v.context) == 0
}

// stringValue returns v unchanged.
// A stringValue is already its own string representation.
func (v stringValue) stringValue() stringValue {
return v
}

func (v stringValue) toNumber() (floatValue, bool) {
f, err := lualex.ParseNumber(v.s)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions internal/mylua/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,17 @@ func (l *State) exec() (err error) {
return err
}
*ra = result
case luacode.OpConcat:
a, b := i.ArgA(), i.ArgB()
top := int(a) + int(b)
if top > int(f.proto.MaxStackSize) {
return fmt.Errorf("decode instruction (pc=%d): concat: register %d out-of-bounds (stack is %d slots)",
frame.pc, top-1, f.proto.MaxStackSize)
}
l.setTop(frame.registerStart() + top)
if err := l.concat(int(b)); err != nil {
return err
}
case luacode.OpJMP:
nextPC += int(i.J())
case luacode.OpTest:
Expand Down
52 changes: 52 additions & 0 deletions internal/mylua/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,56 @@ func TestVM(t *testing.T) {
t.Errorf("state.ToInteger(-1) = %d, %t; want %d, true", got, ok, want)
}
})

// Concat2 checks that a two-operand concatenation
// of a string constant and a string global produces the joined string.
t.Run("Concat2", func(t *testing.T) {
	const (
		source = `return "Hello, "..x`
		want   = "Hello, World"
	)

	l := new(State)
	defer func() {
		if closeErr := l.Close(); closeErr != nil {
			t.Error("Close:", closeErr)
		}
	}()

	// Make "World" available to the chunk as the global x.
	l.PushString("World")
	if err := l.SetGlobal("x", 0); err != nil {
		t.Fatal(err)
	}

	// Compile the chunk, then run it expecting a single result.
	chunk := strings.NewReader(source)
	if err := l.Load(chunk, luacode.Source(source), "t"); err != nil {
		t.Fatal(err)
	}
	if err := l.Call(0, 1, 0); err != nil {
		t.Fatal(err)
	}

	got, ok := l.ToString(-1)
	if got != want || !ok {
		t.Errorf("state.ToString(-1) = %q, %t; want %q, true", got, ok, want)
	}
})

// Concat3 checks that a three-operand concatenation
// (constant, global, constant) produces the joined string.
t.Run("Concat3", func(t *testing.T) {
	const (
		source = `return "Hello, "..x.."!"`
		want   = "Hello, World!"
	)

	l := new(State)
	defer func() {
		if closeErr := l.Close(); closeErr != nil {
			t.Error("Close:", closeErr)
		}
	}()

	// Make "World" available to the chunk as the global x.
	l.PushString("World")
	if err := l.SetGlobal("x", 0); err != nil {
		t.Fatal(err)
	}

	// Compile the chunk, then run it expecting a single result.
	chunk := strings.NewReader(source)
	if err := l.Load(chunk, luacode.Source(source), "t"); err != nil {
		t.Fatal(err)
	}
	if err := l.Call(0, 1, 0); err != nil {
		t.Fatal(err)
	}

	got, ok := l.ToString(-1)
	if got != want || !ok {
		t.Errorf("state.ToString(-1) = %q, %t; want %q, true", got, ok, want)
	}
})
}

0 comments on commit 2795619

Please sign in to comment.