Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in glob collections #10

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package diderot

import (
"fmt"
"sync"
"time"

"github.com/linkedin/diderot/ads"
Expand Down Expand Up @@ -216,24 +217,44 @@ func (c *cache[T]) IsSubscribedTo(name string, handler ads.SubscriptionHandler[T
}

func (c *cache[T]) Subscribe(name string, handler ads.SubscriptionHandler[T]) {
wait := func(wgs ...*sync.WaitGroup) {
for _, wg := range wgs {
wg.Wait()
}
}

if name == ads.WildcardSubscription {
subscribedAt, version := c.wildcardSubscribers.Subscribe(handler)
c.EntryNames(func(name string) bool {

var waitGroups []*sync.WaitGroup
for name := range c.EntryNames {
// Cannot call c.Subscribe here because it always creates a backing watchableValue if it does not
// already exist. For wildcard subscriptions, if the entry doesn't exist (or in this case has been
// deleted), a subscription isn't necessary. If the entry reappears, it will be automatically
// subscribed to.
c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) {
value.NotifyHandlerAfterSubscription(handler, internal.WildcardSubscription, subscribedAt, version)
wg := value.NotifyHandlerAfterSubscription(
handler,
internal.WildcardSubscription,
subscribedAt,
version,
)
if wg != nil {
waitGroups = append(waitGroups, wg)
}
})
return true
})
}
wait(waitGroups...)
} else if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil {
c.globCollections.Subscribe(gcURL, handler)
wait(c.globCollections.Subscribe(gcURL, handler)...)
} else {
var wg *sync.WaitGroup
c.createOrModifyEntry(name, func(name string, value *internal.WatchableValue[T]) {
value.Subscribe(handler)
wg = value.Subscribe(handler)
})
if wg != nil {
wg.Wait()
}
}
}

Expand Down Expand Up @@ -337,7 +358,11 @@ type cacheWithPriority[T proto.Message] struct {
func (c *cacheWithPriority[T]) Clear(name string, clearedAt time.Time) {
var shouldDelete bool
c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) {
shouldDelete = value.Clear(c.p, clearedAt) && value.SubscriberSets[internal.ExplicitSubscription].Size() == 0
isFullClear := value.Clear(c.p, clearedAt)
if gcURL, err := ads.ExtractGlobCollectionURLFromResourceURN(name, c.trimmedTypeURL); err == nil {
c.globCollections.RemoveValueFromCollection(gcURL, value)
}
shouldDelete = isFullClear && value.SubscriberSets[internal.ExplicitSubscription].Size() == 0
})
if shouldDelete {
c.deleteEntryIfNilAndNoSubscribers(name)
Expand Down
165 changes: 157 additions & 8 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package diderot_test
import (
"fmt"
"maps"
"math/rand/v2"
"slices"
"sort"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -33,6 +36,8 @@ const (
name3 = "r3"
)

var globCollectionPrefix = ads.XDSTPScheme + "/" + diderot.TypeOf[*Timestamp]().TrimmedURL() + "/"

var noTime time.Time

func TestCacheCrud(t *testing.T) {
Expand Down Expand Up @@ -338,7 +343,7 @@ func TestNotifyMetadata(t *testing.T) {
// loop. The loop should abort and not call the remaining subscribers. It should instead restart and run through each
// subscriber with the updated value.
func TestWatchableValueUpdateCancel(t *testing.T) {
prefix := ads.XDSTPScheme + "/" + diderot.TypeOf[*Timestamp]().TrimmedURL() + "/foo/"
prefix := globCollectionPrefix + "foo/"

c := newCache()

Expand Down Expand Up @@ -442,20 +447,20 @@ func TestCacheEntryDeletion(t *testing.T) {
func TestCacheCollections(t *testing.T) {
c := diderot.NewCache[*Timestamp]()

const prefix = "xdstp:///google.protobuf.Timestamp/"
const prefix = "xdstp:///google.protobuf.Timestamp/a/"

h := make(testutils.ChanSubscriptionHandler[*Timestamp], 1)

c.Subscribe(prefix+"a/*", h)
h.WaitForDelete(t, prefix+"a/*")
c.Subscribe(prefix+"*", h)
h.WaitForDelete(t, prefix+"*")

c.Subscribe(prefix+"a/foo", h)
h.WaitForDelete(t, prefix+"a/foo")
c.Subscribe(prefix+"foo", h)
h.WaitForDelete(t, prefix+"foo")

var updates []testutils.ExpectedNotification[*Timestamp]
var deletes []testutils.ExpectedNotification[*Timestamp]
for i := 0; i < 5; i++ {
name, v := prefix+"a/"+strconv.Itoa(i), strconv.Itoa(i)
name, v := prefix+""+strconv.Itoa(i), strconv.Itoa(i)
updates = append(updates, testutils.ExpectUpdate(c.Set(name, v, Now(), noTime)))
deletes = append(deletes, testutils.ExpectDelete[*Timestamp](name))
}
Expand All @@ -468,7 +473,7 @@ func TestCacheCollections(t *testing.T) {

h.WaitForNotifications(t, deletes...)

h.WaitForDelete(t, prefix+"a/*")
h.WaitForDelete(t, prefix+"*")
}

// TestCache raw validates that the various *Raw methods on the cache work as expected. Namely, raw
Expand Down Expand Up @@ -798,3 +803,147 @@ func DisableTime(tb testing.TB) {
internal.SetTimeProvider(time.Now)
})
}

// The following tests flex various critical sections in the way glob collections are handled (along
// with almost all the cache), to attempt to trigger a race condition. The tests must be run with
// -race for this to have any use.
func TestGlobRace(t *testing.T) {
prefix := globCollectionPrefix + "foo/"

// This tests many writers all competing for writes against overlapping entries.
t.Run("update", func(t *testing.T) {
c := newCache()

const (
entries = 100
writers = 100
count = 100
readers = 100

doneVersion = "done"
)

entryNames := make([]string, entries)
for i := range entryNames {
name := prefix + strconv.Itoa(i)
entryNames[i] = name
}

var writesDone, readsDone sync.WaitGroup
writesDone.Add(writers)
readsDone.Add(writers * readers)

for range readers {
h := testutils.NewSubscriptionHandler(func(name string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) {
if r.Version == doneVersion {
readsDone.Done()
}
})
c.Subscribe(ads.WildcardSubscription, h)
}

for range writers {
names := slices.Clone(entryNames)
shuffle := func() []string {
rand.Shuffle(len(names), func(i, j int) {
names[i], names[j] = names[j], names[i]
})
return names
}
go func() {
defer writesDone.Done()
for range count {
for _, name := range shuffle() {
c.Set(name, "1", new(Timestamp), time.Time{})
}
}
}()
}
go func() {
writesDone.Wait()
for _, name := range entryNames {
c.Set(name, doneVersion, new(Timestamp), time.Time{})
}
}()

readsDone.Wait()
})
// This tests subscribing to a collection that is currently being updated.
t.Run("subscribe", func(t *testing.T) {
c := newCache()

stop := make(chan struct{})
t.Cleanup(func() { close(stop) })

for i := range 1 {
name := prefix + strconv.Itoa(i)
go func() {
for {
select {
case <-stop:
return
default:
}
c.Set(name, "", new(Timestamp), noTime)
time.Sleep(time.Nanosecond)
c.Clear(name, noTime)
}
}()
}
h := testutils.NewSubscriptionHandler[*Timestamp](func(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {})

// 1000 chosen arbitrarily, can be increased to increase likelihood of race condition, but the test
// will take longer to run.
for range 1000 {
c.Subscribe(prefix+ads.WildcardSubscription, h)
}
})
// This tests multiple subscribers all subscribing at once.
t.Run("concurrent subscriptions", func(t *testing.T) {
c := newCache()

const (
entries = 100
subscribers = 100
)

for i := range entries {
name := prefix + strconv.Itoa(i)
c.Set(name, "", Now(), noTime)
}

var done sync.WaitGroup
done.Add(subscribers)

// Set to true if inFlight is greater than one.
var multipleInFlight atomic.Bool
// Incremented when the Subscription loop starts, and decremented when it ends. If only one
// subscriber can go through the elements of a glob collection at once, then this will never be
// greater than once, and multipleInFlight will therefore never be set to true, failing the test.
var inFlight atomic.Int32

for range subscribers {
go func() {
defer done.Done()
var remainingEntries atomic.Int32
remainingEntries.Add(entries)
handlerFunc := func(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {
switch remainingEntries.Add(-1) {
case entries - 1:
if inFlight.Add(1) > 1 {
multipleInFlight.Store(true)
}
case 0:
inFlight.Add(-1)
}
}

c.Subscribe(prefix+ads.WildcardSubscription, testutils.NewSubscriptionHandler[*Timestamp](handlerFunc))
}()
}

done.Wait()
require.Equal(t, int32(0), inFlight.Load())
require.True(t, multipleInFlight.Load())
})
}
Loading