Skip to content

Commit

Permalink
Optimize performance of Lua pattern-matching
Browse files Browse the repository at this point in the history
Skip exploring the same state more than once in the same iteration.
This bounds work and memory usage.

Some other small allocation optimizations:

- Reuse capture lists where possible.
- Preallocate a pool of pattern states.

Added pattern-matching benchmarks to measure improvements.
  • Loading branch information
zombiezen committed Feb 14, 2025
1 parent f297903 commit c649321
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 33 deletions.
122 changes: 89 additions & 33 deletions internal/lua/pattern.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
// [Lua pattern]: https://www.lua.org/manual/5.4/manual.html#6.4.1
type pattern struct {
start patternState
numStates int
numCaptures int
positionCaptures sets.Bit
}
Expand All @@ -43,21 +44,47 @@ type patternState struct {
}

func parsePattern(p string) (*pattern, error) {
result := new(pattern)
result := &pattern{numStates: 1}
numSplits := 0

p, anchored := strings.CutPrefix(p, "^")
estimatedStateSpace := len(p) + 1 // 1 extra for closing parenthesis.
if !anchored {
// Factor in "." and split states.
estimatedStateSpace += 2
}
statePool := make([]patternState, estimatedStateSpace)
newPatternState := func(item string, out *patternState) *patternState {
result.numStates++
if len(statePool) == 0 {
// If we underallocated, allocate more in small batches.
// 4 words * 32 = 256 bytes on 64-bit machines.
statePool = make([]patternState, 8)
}
s := &statePool[0]
s.item = item
s.out[0] = out
statePool = statePool[1:]
return s
}
newSplitState := func(out0, out1 *patternState) *patternState {
numSplits++
s := newPatternState(patternStateSplit, out0)
s.out[1] = out1
return s
}

var out [2]**patternState
if anchored {
result.start.item = "("
out[0] = &result.start.out[0]
} else {
// Start with the equivalent of ".*(".
result.start.item = patternStateSplit
startState := &patternState{item: "("}
numSplits++
startState := newPatternState("(", nil)
result.start.out[0] = startState
result.start.out[1] = &patternState{
item: ".",
out: [2]*patternState{&result.start},
}
result.start.out[1] = newPatternState(".", &result.start)
out[0] = &startState.out[0]
}

Expand All @@ -81,7 +108,7 @@ func parsePattern(p string) (*pattern, error) {
return nil, errors.New("too many captures")
}
captureDepth++
newState := &patternState{item: "("}
newState := newPatternState("(", nil)
patch(newState)
out[0] = &newState.out[0]
p = p[1:]
Expand All @@ -90,12 +117,12 @@ func parsePattern(p string) (*pattern, error) {
return nil, errors.New("invalid pattern capture")
}
captureDepth--
newState := &patternState{item: ")"}
newState := newPatternState(")", nil)
patch(newState)
out[0] = &newState.out[0]
p = p[1:]
case p == "$":
newState := &patternState{item: patternStateSuffixAnchor}
newState := newPatternState(patternStateSuffixAnchor, nil)
patch(newState)
out[0] = &newState.out[0]
p = p[1:]
Expand All @@ -113,7 +140,7 @@ func parsePattern(p string) (*pattern, error) {
return nil, err
}
n += len("%f")
newState := &patternState{item: p[:n]}
newState := newPatternState(p[:n], nil)
patch(newState)
out[0] = &newState.out[0]
p = p[n:]
Expand All @@ -123,43 +150,34 @@ func parsePattern(p string) (*pattern, error) {
if err != nil {
return nil, err
}
newState := &patternState{item: p[:n]}
newState := newPatternState(p[:n], nil)

var modifier byte
if n < len(p) {
modifier = p[n]
}
switch modifier {
case '?':
splitState := &patternState{
item: patternStateSplit,
out: [2]*patternState{newState},
}
splitState := newSplitState(newState, nil)
patch(splitState)
out[0] = &newState.out[0]
out[1] = &splitState.out[1]
p = p[n+1:]
case '+':
patch(newState)
out[0] = &newState.out[0]
newState = &patternState{item: newState.item}
newState = newPatternState(newState.item, nil)
fallthrough
case '*':
// Zero or more, prefer longer.
splitState := &patternState{
item: patternStateSplit,
out: [2]*patternState{newState},
}
splitState := newSplitState(newState, nil)
newState.out[0] = splitState
patch(splitState)
out[0] = &splitState.out[1]
p = p[n+1:]
case '-':
// Zero or more, prefer shorter.
splitState := &patternState{
item: patternStateSplit,
out: [2]*patternState{nil, newState},
}
splitState := newSplitState(nil, newState)
newState.out[0] = splitState
patch(splitState)
out[0] = &splitState.out[0]
Expand All @@ -175,8 +193,12 @@ func parsePattern(p string) (*pattern, error) {
if captureDepth > 0 {
return nil, errors.New("unfinished capture")
}
if numSplits > 200 {
// Limit recursion depth in addState.
return nil, errors.New("pattern too complex")
}
// Close out match.
patch(&patternState{item: ")"})
patch(newPatternState(")", nil))

return result, nil
}
Expand Down Expand Up @@ -231,38 +253,68 @@ func (p *pattern) find(s string, pos int) []int {
}

capturesCap := (p.numCaptures + 1) * 2
var currList, nextList []matchState
visited := make(sets.Set[*patternState], p.numStates)
currList := make([]matchState, 0, p.numStates)
nextList := make([]matchState, 0, p.numStates)
freeCaptures := make([][]int, 0, p.numStates) // Freelist of captures.
var addState func(matchState)
addState = func(curr matchState) {
// Advance past zero-length states.
for {
// A terminal state needs no further processing and is always added.
if curr.state == nil {
nextList = append(nextList, curr)
return
}

// A state can appear at most once in a list.
if visited.Has(curr.state) {
freeCaptures = append(freeCaptures, curr.captures)
return
}
visited.Add(curr.state)

switch {
case curr.state != nil && curr.state.item == "(":
case curr.state.item == "(":
curr.captures = append(curr.captures, pos, -1)
curr.state = curr.state.out[0]
case curr.state != nil && curr.state.item == ")":
case curr.state.item == ")":
// Fill in the end index of the most recently opened capture.
i := lastIndex(curr.captures, -1)
if i == -1 {
panic("unmatched parenthesis")
}
curr.captures[i] = pos
curr.state = curr.state.out[0]
case curr.state != nil && curr.state.item == patternStateSplit:
capturesCopy := append(make([]int, 0, capturesCap), curr.captures...)
// TODO(soon): Remove recursive call or check depth.
case curr.state.item == patternStateSplit:
// Clone the captures from the current state.
// Reuse captures from previously discarded states if any.
var capturesCopy []int
if len(freeCaptures) == 0 {
capturesCopy = make([]int, 0, capturesCap)
} else {
i := len(freeCaptures) - 1
capturesCopy = freeCaptures[i][:0]
freeCaptures[i] = nil
freeCaptures = freeCaptures[:i]
}
capturesCopy = append(capturesCopy, curr.captures...)

// Recursive call bounded by number of splits in the pattern.
// [parsePattern] performs a hard limit.
addState(matchState{
state: curr.state.out[0],
captures: curr.captures,
})
curr.captures = capturesCopy
curr.state = curr.state.out[1]
case curr.state != nil && curr.state.item == patternStateSuffixAnchor:
case curr.state.item == patternStateSuffixAnchor:
if pos < len(s) {
freeCaptures = append(freeCaptures, curr.captures)
return
}
curr.state = curr.state.out[0]
case curr.state != nil && strings.HasPrefix(curr.state.item, "%f["):
case strings.HasPrefix(curr.state.item, "%f["):
set := curr.state.item[len("%f[") : len(curr.state.item)-1]
var prev, next byte
if pos > 0 {
Expand All @@ -272,6 +324,7 @@ func (p *pattern) find(s string, pos int) []int {
next = s[pos]
}
if matchBracketClass(prev, set) || !matchBracketClass(next, set) {
freeCaptures = append(freeCaptures, curr.captures)
return
}
curr.state = curr.state.out[0]
Expand All @@ -296,6 +349,7 @@ func (p *pattern) find(s string, pos int) []int {
return currList[0].captures
}

clear(visited)
clear(nextList)
nextList = nextList[:0]
c := s[pos]
Expand All @@ -310,6 +364,8 @@ func (p *pattern) find(s string, pos int) []int {
state: curr.state.out[0],
captures: curr.captures,
})
} else {
freeCaptures = append(freeCaptures, curr.captures)
}
}
}
Expand Down
101 changes: 101 additions & 0 deletions internal/lua/stringlib_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ func TestStringFind(t *testing.T) {
init: 1,
want: []any{int64(1), int64(0)},
},
{
s: "abc",
pattern: "",
init: 1,
want: []any{int64(1), int64(0)},
},
{
s: "aaa",
pattern: "^a",
Expand Down Expand Up @@ -572,6 +578,101 @@ func TestStringFind(t *testing.T) {
}
}

func BenchmarkStringFind(b *testing.B) {
ctx := context.Background()
state := new(State)
defer func() {
if err := state.Close(); err != nil {
b.Error("Close:", err)
}
}()

state.PushClosure(0, OpenString)
if err := state.Call(ctx, 0, 1); err != nil {
b.Fatal(err)
}
if _, err := state.Field(ctx, -1, "find"); err != nil {
b.Fatal(err)
}

benchmarks := []struct {
name string
s string
pattern string
wantStart int64
wantEnd int64
}{
{
name: "SingleByte",
s: "abc",
pattern: "b",
wantStart: 2,
wantEnd: 2,
},
{
name: "Word",
s: "aaabbbccc",
pattern: "bbb",
wantStart: 4,
wantEnd: 6,
},
{
name: "SpaceSeparatedFields",
s: "foo bar baz quux xyzzy",
pattern: ".* .* .* .* .*",
wantStart: 1,
wantEnd: 22,
},
{
name: "SpaceSeparatedCaptures",
s: "foo bar baz quux xyzzy",
pattern: "(.*) (.*) (.*) (.*) (.*)",
wantStart: 1,
wantEnd: 22,
},
// Test case presented in introduction of https://swtch.com/~rsc/regexp/regexp1.html
{
name: "WorstCase",
s: strings.Repeat("a", 30),
pattern: strings.Repeat("a?", 30) + strings.Repeat("a", 30),
wantStart: 1,
wantEnd: 30,
},
}

for _, benchmark := range benchmarks {
b.Run(benchmark.name, func(b *testing.B) {
defer state.SetTop(state.Top())

b.SetBytes(int64(len(benchmark.s)))
for range b.N {
state.PushValue(-1)
state.PushString(benchmark.s)
state.PushString(benchmark.pattern)
if err := state.Call(ctx, 2, 2); err != nil {
b.Fatal(err)
}

start, startOK := state.ToInteger(-2)
end, endOK := state.ToInteger(-1)
if !startOK {
b.Errorf("type(select(1, string.find(%s, %s))) = %v; want integer",
lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), state.Type(-2))
}
if !endOK {
b.Errorf("type(select(2, string.find(%s, %s))) = %v; want integer",
lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), state.Type(-1))
}
if startOK && endOK && (start != benchmark.wantStart || end != benchmark.wantEnd) {
b.Errorf("string.find(%s, %s) = %d, %d; want %d, %d",
lualex.Quote(benchmark.s), lualex.Quote(benchmark.pattern), start, end, benchmark.wantStart, benchmark.wantEnd)
}
state.Pop(2)
}
})
}
}

func TestStringMatch(t *testing.T) {
tests := []struct {
s string
Expand Down

0 comments on commit c649321

Please sign in to comment.