Skip to content

Commit

Permalink
Add test for upvalues with function closure
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen committed Dec 16, 2024
1 parent a0ea228 commit 0803dc8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
6 changes: 3 additions & 3 deletions internal/mylua/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ func (l *State) checkUpvalues(upvalues []*upvalue) error {
}

// closeUpvalues moves the values of any upvalues
// that refer to stack values at indices less than top
// that refer to stack values at indices greater than or equal to top
// off to the stack, thus “closing” them.
// This is distinct from calling the “__close” metamethods,
// but often happens at the same time.
func (l *State) closeUpvalues(top int) {
func (l *State) closeUpvalues(bottom int) {
n := 0
for _, uv := range l.pendingVariables {
if uv.isOpen() && uv.stackIndex >= top {
if uv.isOpen() && uv.stackIndex >= bottom {
// Close the upvalue.
uv.storage = l.stack[uv.stackIndex]
uv.stackIndex = -1
Expand Down
43 changes: 43 additions & 0 deletions internal/mylua/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,4 +369,47 @@ func TestVM(t *testing.T) {
t.Errorf("emit sequence = %v; want %v", got, want)
}
})

t.Run("Upvalues", func(t *testing.T) {
state := new(State)
defer func() {
if err := state.Close(); err != nil {
t.Error("Close:", err)
}
}()

var got []int64
state.PushClosure(0, func(state *State) (int, error) {
state.SetTop(1)
i, ok := state.ToInteger(1)
if !ok {
t.Errorf("on call %d, emit received %v", len(got)+1, state.Type(1))
}
got = append(got, i)
return 0, nil
})
if err := state.SetGlobal("emit", 0); err != nil {
t.Fatal(err)
}

const source = `local function counter()` + "\n" +
`local x = 1` + "\n" +
`return function() x = x + 1; return x - 1 end` + "\n" +
`end` + "\n" +
`local c = counter()` + "\n" +
`emit(c())` + "\n" +
`emit(c())` + "\n" +
`emit(c())` + "\n"
if err := state.Load(strings.NewReader(source), luacode.Source(source), "t"); err != nil {
t.Fatal(err)
}
if err := state.Call(0, 0, 0); err != nil {
t.Fatal(err)
}

want := []int64{1, 2, 3}
if !slices.Equal(want, got) {
t.Errorf("emit sequence = %v; want %v", got, want)
}
})
}

0 comments on commit 0803dc8

Please sign in to comment.