Skip to content

Fix race condition in glob collections #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ COVERPKG = .

# Runs the Go test suite with the race detector enabled, writing a coverage profile to $(COVERAGE),
# then prints the figure from the "total:" line of `go tool cover -func`.
test:
@mkdir -p $(dir $(COVERAGE))
go test -race -coverprofile=$(COVERAGE) -coverpkg=$(COVERPKG)/... -count=$(TESTCOUNT) $(TESTFLAGS) $(TESTPKG)
go test -v -race -coverprofile=$(COVERAGE) -coverpkg=$(COVERPKG)/... -count=$(TESTCOUNT) $(TESTFLAGS) $(TESTPKG)
go tool cover -func $(COVERAGE) | awk '/total:/{print "Coverage: "$$3}'

.PHONY: $(COVERAGE)

# Regenerates the coverage profile by invoking the test target, then renders it with
# `go tool cover -html`. The leading "-" on the $(MAKE) line tells make to ignore a failing test
# run so that the report is still produced.
coverage: $(COVERAGE)
$(COVERAGE):
@mkdir -p $(@D)
$(MAKE) test TESTFLAGS='-coverprofile=$(COVERAGE) -coverpkg=$(COVERPKG)/...'
-$(MAKE) test
go tool cover -html=$(COVERAGE)

profile_cache:
Expand Down
47 changes: 40 additions & 7 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package diderot

import (
"fmt"
"sync"
"time"

"github.com/linkedin/diderot/ads"
Expand Down Expand Up @@ -212,23 +213,51 @@ func (c *cache[T]) IsSubscribedTo(name string, handler ads.SubscriptionHandler[T
}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment on lines 90 and 91 still true?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, absolutely! It's only when it's the same SubscriptionHandler. However, different goroutines can still interact with the cache as long as it's for different SubscriptionHandlers.

// Subscribe registers handler for the given name, which is treated as one of: the wildcard
// subscription, a glob collection URL, or the name of a single entry. Any WaitGroups collected
// while subscribing are waited on (in a deferred function) before Subscribe returns.
func (c *cache[T]) Subscribe(name string, handler ads.SubscriptionHandler[T]) {
// More details about this can be found in the documentation of WatchableValue. The short version
// is that Subscribe must not return until all notifications are delivered, and subscribing to a
// WatchableValue returns a WaitGroup if the loop goroutine will deliver the notification instead
// of this goroutine. To respect the Subscribe contract, wait for all the returned WaitGroups.
// Crucially however, because the various Compute methods on the resource map hold locks, waiting
// on the WaitGroups must be done *outside* of the lambdas passed to those methods. Hence, to avoid
// accidentally blocking anything, wait for the WaitGroups in a deferred statement.
var waitGroups []*sync.WaitGroup
appendWg := func(wg *sync.WaitGroup) {
// Only add non-nil WaitGroups to the slice. This avoids nil-pointer panics when waiting, but
// primarily it avoids allocating a large slice of mostly nil WaitGroups, since it is rare for
// WatchableValue.Subscribe to actually return a non-nil WaitGroup.
if wg != nil {
waitGroups = append(waitGroups, wg)
}
}
// Runs after the branches below have finished, i.e. once no Compute lambda is executing anymore.
defer func() {
for _, wg := range waitGroups {
wg.Wait()
}
}()

if name == ads.WildcardSubscription {
subscribedAt, version := c.wildcardSubscribers.Subscribe(handler)
c.EntryNames(func(name string) bool {

for name := range c.EntryNames {
// Cannot call c.Subscribe here because it always creates a backing watchableValue if it does not
// already exist. For wildcard subscriptions, if the entry doesn't exist (or in this case has been
// deleted), a subscription isn't necessary. If the entry reappears, it will be automatically
// subscribed to.
c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) {
value.NotifyHandlerAfterSubscription(handler, internal.WildcardSubscription, subscribedAt, version)
appendWg(value.NotifyHandlerAfterSubscription(
handler,
internal.WildcardSubscription,
subscribedAt,
version,
))
})
return true
})
}
} else if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil {
c.globCollections.Subscribe(gcURL, handler)
// Nothing has been appended on this branch, so the returned WaitGroups are assigned directly.
waitGroups = c.globCollections.Subscribe(gcURL, handler)
} else {
c.createOrModifyEntry(name, func(name string, value *internal.WatchableValue[T]) {
value.Subscribe(handler)
appendWg(value.Subscribe(handler))
})
}
}
Expand Down Expand Up @@ -347,7 +376,11 @@ type cacheWithPriority[T proto.Message] struct {
func (c *cacheWithPriority[T]) Clear(name string, clearedAt time.Time) {
var shouldDelete bool
c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) {
shouldDelete = value.Clear(c.p, clearedAt) && value.SubscriberSets[internal.ExplicitSubscription].Size() == 0
isFullClear := value.Clear(c.p, clearedAt)
if gcURL, err := parseGlobCollectionURN[T](name); err == nil {
c.globCollections.RemoveValueFromCollection(gcURL, value)
}
shouldDelete = isFullClear && value.SubscriberSets[internal.ExplicitSubscription].Size() == 0
})
if shouldDelete {
c.deleteEntryIfNilAndNoSubscribers(name)
Expand Down
165 changes: 157 additions & 8 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package diderot_test
import (
"fmt"
"maps"
"math/rand/v2"
"slices"
"sort"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -33,6 +36,8 @@ const (
name3 = "r3"
)

var globCollectionPrefix = ads.XDSTPScheme + "/" + diderot.TypeOf[*Timestamp]().TrimmedURL() + "/"

var noTime time.Time

func TestCacheCrud(t *testing.T) {
Expand Down Expand Up @@ -338,7 +343,7 @@ func TestNotifyMetadata(t *testing.T) {
// loop. The loop should abort and not call the remaining subscribers. It should instead restart and run through each
// subscriber with the updated value.
func TestWatchableValueUpdateCancel(t *testing.T) {
prefix := ads.XDSTPScheme + "/" + diderot.TypeOf[*Timestamp]().TrimmedURL() + "/foo/"
prefix := globCollectionPrefix + "foo/"

c := newCache()

Expand Down Expand Up @@ -442,20 +447,20 @@ func TestCacheEntryDeletion(t *testing.T) {
func TestCacheCollections(t *testing.T) {
c := diderot.NewCache[*Timestamp]()

const prefix = "xdstp:///google.protobuf.Timestamp/"
const prefix = "xdstp:///google.protobuf.Timestamp/a/"

h := make(testutils.ChanSubscriptionHandler[*Timestamp], 1)

c.Subscribe(prefix+"a/*", h)
h.WaitForDelete(t, prefix+"a/*")
c.Subscribe(prefix+"*", h)
h.WaitForDelete(t, prefix+"*")

c.Subscribe(prefix+"a/foo", h)
h.WaitForDelete(t, prefix+"a/foo")
c.Subscribe(prefix+"foo", h)
h.WaitForDelete(t, prefix+"foo")

var updates []testutils.ExpectedNotification[*Timestamp]
var deletes []testutils.ExpectedNotification[*Timestamp]
for i := 0; i < 5; i++ {
name, v := prefix+"a/"+strconv.Itoa(i), strconv.Itoa(i)
name, v := prefix+""+strconv.Itoa(i), strconv.Itoa(i)
updates = append(updates, testutils.ExpectUpdate(c.Set(name, v, Now(), noTime)))
deletes = append(deletes, testutils.ExpectDelete[*Timestamp](name))
}
Expand All @@ -468,7 +473,7 @@ func TestCacheCollections(t *testing.T) {

h.WaitForNotifications(t, deletes...)

h.WaitForDelete(t, prefix+"a/*")
h.WaitForDelete(t, prefix+"*")
}

// TestCache raw validates that the various *Raw methods on the cache work as expected. Namely, raw
Expand Down Expand Up @@ -798,3 +803,147 @@ func DisableTime(tb testing.TB) {
internal.SetTimeProvider(time.Now)
})
}

// TestGlobRace exercises various critical sections in the way glob collections are handled (along
// with almost all of the cache) in an attempt to trigger a race condition. The subtests must be run
// with -race for this to have any use.
func TestGlobRace(t *testing.T) {
	prefix := globCollectionPrefix + "foo/"

	// This tests many writers all competing for writes against overlapping entries.
	t.Run("update", func(t *testing.T) {
		c := newCache()

		const (
			entries = 100
			writers = 100
			count   = 100
			readers = 100

			doneVersion = "done"
		)

		entryNames := make([]string, entries)
		for i := range entryNames {
			entryNames[i] = prefix + strconv.Itoa(i)
		}

		var writesDone, readsDone sync.WaitGroup
		writesDone.Add(writers)
		// Every reader is a wildcard subscriber and is expected to be notified of the final doneVersion
		// write once per entry, so the number of expected Done calls is entries * readers. It does not
		// depend on the number of writers.
		readsDone.Add(entries * readers)

		for range readers {
			h := testutils.NewSubscriptionHandler(func(name string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) {
				if r.Version == doneVersion {
					readsDone.Done()
				}
			})
			c.Subscribe(ads.WildcardSubscription, h)
		}

		for range writers {
			// Each writer shuffles its own private copy of the names, so that writers visit the entries in
			// different orders without sharing a slice.
			names := slices.Clone(entryNames)
			shuffle := func() []string {
				rand.Shuffle(len(names), func(i, j int) {
					names[i], names[j] = names[j], names[i]
				})
				return names
			}
			go func() {
				defer writesDone.Done()
				for range count {
					for _, name := range shuffle() {
						c.Set(name, "1", new(Timestamp), time.Time{})
					}
				}
			}()
		}
		// Once all writers have finished, write the sentinel version to every entry. The readers count
		// these sentinel notifications to know when they are done.
		go func() {
			writesDone.Wait()
			for _, name := range entryNames {
				c.Set(name, doneVersion, new(Timestamp), time.Time{})
			}
		}()

		readsDone.Wait()
	})
	// This tests subscribing to a collection whose entries are being concurrently added/updated/cleared.
	t.Run("subscribe", func(t *testing.T) {
		c := newCache()

		stop := make(chan struct{})
		var background sync.WaitGroup
		// Signal the background goroutines to stop and wait for them to exit, so that none of them
		// outlives this subtest.
		t.Cleanup(func() {
			close(stop)
			background.Wait()
		})

		for i := range 1 {
			name := prefix + strconv.Itoa(i)
			background.Add(1)
			go func() {
				defer background.Done()
				for {
					select {
					case <-stop:
						return
					default:
					}
					c.Set(name, "", new(Timestamp), noTime)
					time.Sleep(time.Nanosecond)
					c.Clear(name, noTime)
				}
			}()
		}
		h := testutils.NewSubscriptionHandler[*Timestamp](func(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {})

		// 1000 chosen arbitrarily, can be increased to increase likelihood of race condition, but the test
		// will take longer to run.
		for range 1000 {
			c.Subscribe(prefix+ads.WildcardSubscription, h)
		}
	})
	// This tests multiple subscribers all subscribing at once.
	t.Run("concurrent subscriptions", func(t *testing.T) {
		c := newCache()

		const (
			entries     = 100
			subscribers = 100
		)

		for i := range entries {
			name := prefix + strconv.Itoa(i)
			c.Set(name, "", Now(), noTime)
		}

		var done sync.WaitGroup
		done.Add(subscribers)

		// Set to true if inFlight is ever greater than one.
		var multipleInFlight atomic.Bool
		// Incremented when the Subscription loop starts, and decremented when it ends. If only one
		// subscriber can go through the elements of a glob collection at once, then this will never be
		// greater than one, and multipleInFlight will therefore never be set to true, failing the test.
		var inFlight atomic.Int32

		for range subscribers {
			go func() {
				defer done.Done()
				var remainingEntries atomic.Int32
				remainingEntries.Add(entries)
				handlerFunc := func(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {
					switch remainingEntries.Add(-1) {
					case entries - 1:
						// First notification for this subscriber: its subscription loop has started.
						if inFlight.Add(1) > 1 {
							multipleInFlight.Store(true)
						}
					case 0:
						// Last notification for this subscriber: its subscription loop has ended.
						inFlight.Add(-1)
					}
				}

				c.Subscribe(prefix+ads.WildcardSubscription, testutils.NewSubscriptionHandler[*Timestamp](handlerFunc))
			}()
		}

		done.Wait()
		require.Equal(t, int32(0), inFlight.Load())
		require.True(t, multipleInFlight.Load())
	})
}
Loading