From c6493215b5b910a61f9202967be7c92042d1945f Mon Sep 17 00:00:00 2001 From: Roxy Light Date: Thu, 13 Feb 2025 17:35:21 -0800 Subject: [PATCH] Optimize performance of Lua pattern-matching Skip exploring the same state more than once in the same iteration. This bounds work and memory usage. Some other small allocation optimizations: - Reuse capture lists where possible. - Preallocate a pool of pattern states. Added pattern-matching benchmarks to measure improvements. --- internal/lua/pattern.go | 122 ++++++++++++++++++++++++--------- internal/lua/stringlib_test.go | 101 +++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 33 deletions(-) diff --git a/internal/lua/pattern.go b/internal/lua/pattern.go index a6cebdf..d678085 100644 --- a/internal/lua/pattern.go +++ b/internal/lua/pattern.go @@ -25,6 +25,7 @@ const ( // [Lua pattern]: https://www.lua.org/manual/5.4/manual.html#6.4.1 type pattern struct { start patternState + numStates int numCaptures int positionCaptures sets.Bit } @@ -43,8 +44,36 @@ type patternState struct { } func parsePattern(p string) (*pattern, error) { - result := new(pattern) + result := &pattern{numStates: 1} + numSplits := 0 + p, anchored := strings.CutPrefix(p, "^") + estimatedStateSpace := len(p) + 1 // 1 extra for closing parenthesis. + if !anchored { + // Factor in "." and split states. + estimatedStateSpace += 2 + } + statePool := make([]patternState, estimatedStateSpace) + newPatternState := func(item string, out *patternState) *patternState { + result.numStates++ + if len(statePool) == 0 { + // If we underallocated, allocate more in small batches. + // 8 states * 4 words each = 32 words = 256 bytes on 64-bit machines. 
+ statePool = make([]patternState, 8) + } + s := &statePool[0] + s.item = item + s.out[0] = out + statePool = statePool[1:] + return s + } + newSplitState := func(out0, out1 *patternState) *patternState { + numSplits++ + s := newPatternState(patternStateSplit, out0) + s.out[1] = out1 + return s + } + var out [2]**patternState if anchored { result.start.item = "(" @@ -52,12 +81,10 @@ func parsePattern(p string) (*pattern, error) { } else { // Start with the equivalent of ".*(". result.start.item = patternStateSplit - startState := &patternState{item: "("} + numSplits++ + startState := newPatternState("(", nil) result.start.out[0] = startState - result.start.out[1] = &patternState{ - item: ".", - out: [2]*patternState{&result.start}, - } + result.start.out[1] = newPatternState(".", &result.start) out[0] = &startState.out[0] } @@ -81,7 +108,7 @@ func parsePattern(p string) (*pattern, error) { return nil, errors.New("too many captures") } captureDepth++ - newState := &patternState{item: "("} + newState := newPatternState("(", nil) patch(newState) out[0] = &newState.out[0] p = p[1:] @@ -90,12 +117,12 @@ func parsePattern(p string) (*pattern, error) { return nil, errors.New("invalid pattern capture") } captureDepth-- - newState := &patternState{item: ")"} + newState := newPatternState(")", nil) patch(newState) out[0] = &newState.out[0] p = p[1:] case p == "$": - newState := &patternState{item: patternStateSuffixAnchor} + newState := newPatternState(patternStateSuffixAnchor, nil) patch(newState) out[0] = &newState.out[0] p = p[1:] @@ -113,7 +140,7 @@ func parsePattern(p string) (*pattern, error) { return nil, err } n += len("%f") - newState := &patternState{item: p[:n]} + newState := newPatternState(p[:n], nil) patch(newState) out[0] = &newState.out[0] p = p[n:] @@ -123,7 +150,7 @@ func parsePattern(p string) (*pattern, error) { if err != nil { return nil, err } - newState := &patternState{item: p[:n]} + newState := newPatternState(p[:n], nil) var modifier byte if n < 
len(p) { @@ -131,10 +158,7 @@ func parsePattern(p string) (*pattern, error) { } switch modifier { case '?': - splitState := &patternState{ - item: patternStateSplit, - out: [2]*patternState{newState}, - } + splitState := newSplitState(newState, nil) patch(splitState) out[0] = &newState.out[0] out[1] = &splitState.out[1] @@ -142,24 +166,18 @@ func parsePattern(p string) (*pattern, error) { case '+': patch(newState) out[0] = &newState.out[0] - newState = &patternState{item: newState.item} + newState = newPatternState(newState.item, nil) fallthrough case '*': // Zero or more, prefer longer. - splitState := &patternState{ - item: patternStateSplit, - out: [2]*patternState{newState}, - } + splitState := newSplitState(newState, nil) newState.out[0] = splitState patch(splitState) out[0] = &splitState.out[1] p = p[n+1:] case '-': // Zero or more, prefer shorter. - splitState := &patternState{ - item: patternStateSplit, - out: [2]*patternState{nil, newState}, - } + splitState := newSplitState(nil, newState) newState.out[0] = splitState patch(splitState) out[0] = &splitState.out[0] @@ -175,8 +193,12 @@ func parsePattern(p string) (*pattern, error) { if captureDepth > 0 { return nil, errors.New("unfinished capture") } + if numSplits > 200 { + // Limit recursion depth in addState. + return nil, errors.New("pattern too complex") + } // Close out match. - patch(&patternState{item: ")"}) + patch(newPatternState(")", nil)) return result, nil } @@ -231,16 +253,32 @@ func (p *pattern) find(s string, pos int) []int { } capturesCap := (p.numCaptures + 1) * 2 - var currList, nextList []matchState + visited := make(sets.Set[*patternState], p.numStates) + currList := make([]matchState, 0, p.numStates) + nextList := make([]matchState, 0, p.numStates) + freeCaptures := make([][]int, 0, p.numStates) // Freelist of captures. var addState func(matchState) addState = func(curr matchState) { // Advance past zero-length states. 
for { + // A terminal state needs no further processing and is always added. + if curr.state == nil { + nextList = append(nextList, curr) + return + } + + // A state can appear at most once in a list. + if visited.Has(curr.state) { + freeCaptures = append(freeCaptures, curr.captures) + return + } + visited.Add(curr.state) + switch { - case curr.state != nil && curr.state.item == "(": + case curr.state.item == "(": curr.captures = append(curr.captures, pos, -1) curr.state = curr.state.out[0] - case curr.state != nil && curr.state.item == ")": + case curr.state.item == ")": // Fill in the end index of the most recently opened capture. i := lastIndex(curr.captures, -1) if i == -1 { @@ -248,21 +286,35 @@ func (p *pattern) find(s string, pos int) []int { } curr.captures[i] = pos curr.state = curr.state.out[0] - case curr.state != nil && curr.state.item == patternStateSplit: - capturesCopy := append(make([]int, 0, capturesCap), curr.captures...) - // TODO(soon): Remove recursive call or check depth. + case curr.state.item == patternStateSplit: + // Clone the captures from the current state. + // Reuse captures from previously discarded states if any. + var capturesCopy []int + if len(freeCaptures) == 0 { + capturesCopy = make([]int, 0, capturesCap) + } else { + i := len(freeCaptures) - 1 + capturesCopy = freeCaptures[i][:0] + freeCaptures[i] = nil + freeCaptures = freeCaptures[:i] + } + capturesCopy = append(capturesCopy, curr.captures...) + + // Recursive call bounded by number of splits in the pattern. + // [parsePattern] enforces a hard limit on the number of splits. 
addState(matchState{ state: curr.state.out[0], captures: curr.captures, }) curr.captures = capturesCopy curr.state = curr.state.out[1] - case curr.state != nil && curr.state.item == patternStateSuffixAnchor: + case curr.state.item == patternStateSuffixAnchor: if pos < len(s) { + freeCaptures = append(freeCaptures, curr.captures) return } curr.state = curr.state.out[0] - case curr.state != nil && strings.HasPrefix(curr.state.item, "%f["): + case strings.HasPrefix(curr.state.item, "%f["): set := curr.state.item[len("%f[") : len(curr.state.item)-1] var prev, next byte if pos > 0 { @@ -272,6 +324,7 @@ func (p *pattern) find(s string, pos int) []int { next = s[pos] } if matchBracketClass(prev, set) || !matchBracketClass(next, set) { + freeCaptures = append(freeCaptures, curr.captures) return } curr.state = curr.state.out[0] @@ -296,6 +349,7 @@ func (p *pattern) find(s string, pos int) []int { return currList[0].captures } + clear(visited) clear(nextList) nextList = nextList[:0] c := s[pos] @@ -310,6 +364,8 @@ func (p *pattern) find(s string, pos int) []int { state: curr.state.out[0], captures: curr.captures, }) + } else { + freeCaptures = append(freeCaptures, curr.captures) } } } diff --git a/internal/lua/stringlib_test.go b/internal/lua/stringlib_test.go index a2733e7..945a81f 100644 --- a/internal/lua/stringlib_test.go +++ b/internal/lua/stringlib_test.go @@ -117,6 +117,12 @@ func TestStringFind(t *testing.T) { init: 1, want: []any{int64(1), int64(0)}, }, + { + s: "abc", + pattern: "", + init: 1, + want: []any{int64(1), int64(0)}, + }, { s: "aaa", pattern: "^a", @@ -572,6 +578,101 @@ func TestStringFind(t *testing.T) { } } +func BenchmarkStringFind(b *testing.B) { + ctx := context.Background() + state := new(State) + defer func() { + if err := state.Close(); err != nil { + b.Error("Close:", err) + } + }() + + state.PushClosure(0, OpenString) + if err := state.Call(ctx, 0, 1); err != nil { + b.Fatal(err) + } + if _, err := state.Field(ctx, -1, "find"); err != nil { + 
b.Fatal(err) + } + + benchmarks := []struct { + name string + s string + pattern string + wantStart int64 + wantEnd int64 + }{ + { + name: "SingleByte", + s: "abc", + pattern: "b", + wantStart: 2, + wantEnd: 2, + }, + { + name: "Word", + s: "aaabbbccc", + pattern: "bbb", + wantStart: 4, + wantEnd: 6, + }, + { + name: "SpaceSeparatedFields", + s: "foo bar baz quux xyzzy", + pattern: ".* .* .* .* .*", + wantStart: 1, + wantEnd: 22, + }, + { + name: "SpaceSeparatedCaptures", + s: "foo bar baz quux xyzzy", + pattern: "(.*) (.*) (.*) (.*) (.*)", + wantStart: 1, + wantEnd: 22, + }, + // Test case presented in introduction of https://swtch.com/~rsc/regexp/regexp1.html + { + name: "WorstCase", + s: strings.Repeat("a", 30), + pattern: strings.Repeat("a?", 30) + strings.Repeat("a", 30), + wantStart: 1, + wantEnd: 30, + }, + } + + for _, benchmark := range benchmarks { + b.Run(benchmark.name, func(b *testing.B) { + defer state.SetTop(state.Top()) + + b.SetBytes(int64(len(benchmark.s))) + for range b.N { + state.PushValue(-1) + state.PushString(benchmark.s) + state.PushString(benchmark.pattern) + if err := state.Call(ctx, 2, 2); err != nil { + b.Fatal(err) + } + + start, startOK := state.ToInteger(-2) + end, endOK := state.ToInteger(-1) + if !startOK { + b.Errorf("type(select(1, string.find(%s, %s))) = %v; want integer", + lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), state.Type(-2)) + } + if !endOK { + b.Errorf("type(select(2, string.find(%s, %s))) = %v; want integer", + lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), state.Type(-1)) + } + if startOK && endOK && (start != benchmark.wantStart || end != benchmark.wantEnd) { + b.Errorf("string.find(%s, %s) = %d, %d; want %d, %d", + lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), start, end, benchmark.wantStart, benchmark.wantEnd) + } + state.Pop(2) + } + }) + } +} + func TestStringMatch(t *testing.T) { tests := []struct { s string