(concurrentbatchprocessor): Fix deadlock, test EarlyReturn feature (#257)

This tests and fixes a deadlock in the concurrent batch processor. The deadlock
was especially easy to trigger with EarlyReturn=true, but it exists either way.
The bug was introduced in
#251, which mistakenly removed a
`defer func()` from `consumeAndWait()`.

The fix is now covered by tests and does not reintroduce a `defer func() { ...
}`. The older code both counted responses and updated a semaphore; now that we
only count responses and do not update a semaphore, there is no need to defer.
The batcher avoids sending to the waiter when the waiter's context is canceled.
jmacd authored Oct 1, 2024
1 parent 5ca27ef commit 58c3bdf
Showing 2 changed files with 152 additions and 19 deletions.
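
The core of the fix is visible in `sendItems` below: instead of an unconditional send to each waiter's response channel, the batcher now selects between that send and the waiter's context. The following is a minimal, self-contained Go sketch of why that avoids the deadlock; the `pending` struct, the `notifyWaiters` helper, and the standalone `main` are simplified stand-ins for illustration, not the processor's actual API:

// Minimal sketch (assumed names, not the processor's actual types) of the
// send-or-cancel pattern the fix uses in sendItems: deliver the error and
// item count to each waiter, but never block on a waiter that has given up.
package main

import (
	"context"
	"fmt"
	"time"
)

type countedError struct {
	err   error
	count int
}

type pending struct {
	ctx    context.Context
	count  int
	waiter chan countedError
}

// notifyWaiters mirrors the fixed loop: with EarlyReturn there is no response
// channel to fill, and otherwise the select guarantees the batcher cannot
// block forever on an abandoned response channel.
func notifyWaiters(earlyReturn bool, batch []pending, err error) {
	if earlyReturn {
		return
	}
	for _, p := range batch {
		select {
		case p.waiter <- countedError{err: err, count: p.count}:
			// Waiter received the error and count.
		case <-p.ctx.Done():
			// Waiter's context was canceled; drop the response.
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Simulate a caller that abandoned its request.
	// Unbuffered here for brevity; the real response channel has a single
	// slot, which can still fill up when nobody drains it.
	p := pending{ctx: ctx, count: 10, waiter: make(chan countedError)}

	done := make(chan struct{})
	go func() {
		notifyWaiters(false, []pending{p}, nil)
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("batcher returned without blocking")
	case <-time.After(time.Second):
		fmt.Println("deadlock: batcher blocked on an abandoned waiter")
	}
}

With the old unconditional send, the goroutine above would hang whenever the waiter stopped reading, which is exactly the deadlock that the EarlyReturn path made easy to reproduce.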
40 changes: 29 additions & 11 deletions collector/processor/concurrentbatchprocessor/batch_processor.go
@@ -116,12 +116,15 @@ type shard struct {
}

// pendingItem is stored parallel to a pending batch and records
// how many items the waiter submitted, used to ensure the correct
// response count is returned to each waiter.
// how many items the waiter submitted, which is used to link
// the incoming and outgoing traces.
type pendingItem struct {
parentCtx context.Context
numItems int
respCh chan countedError

// respCh is non-nil when the caller is waiting for error
// transmission.
respCh chan countedError
}

// dataItem is exchanged between the waiter and the batching process
@@ -429,8 +432,15 @@ func (b *shard) sendItems(trigger trigger) {
// terminates.
parentSpan.End()

for _, pending := range thisBatch {
pending.waiter <- countedError{err: err, count: pending.count}
if !b.processor.earlyReturn {
for _, pending := range thisBatch {
select {
case pending.waiter <- countedError{err: err, count: pending.count}:
// OK! Caller received error and count.
case <-pending.ctx.Done():
// OK! Caller context was canceled.
}
}
}

if err != nil {
@@ -480,7 +490,7 @@ func allSameContext(x []pendingTuple) bool {
return true
}

func (b *shard) consumeAndWait(ctx context.Context, data any) error {
func (b *shard) consumeBatch(ctx context.Context, data any) error {
var itemCount int
switch telem := data.(type) {
case ptrace.Traces:
@@ -495,7 +505,11 @@ func (b *shard) consumeAndWait(ctx context.Context, data any) error {
return nil
}

respCh := make(chan countedError, 1)
var respCh chan countedError
if !b.processor.earlyReturn {
respCh = make(chan countedError, 1)
}

item := dataItem{
data: data,
pendingItem: pendingItem{
@@ -514,6 +528,10 @@ func (b *shard) consumeAndWait(ctx context.Context, data any) error {
}
}

return b.waitForItems(ctx, item.numItems, respCh)
}

func (b *shard) waitForItems(ctx context.Context, numItems int, respCh chan countedError) error {
var err error
for {
select {
@@ -523,8 +541,8 @@ func (b *shard) consumeAndWait(ctx context.Context, data any) error {
err = errors.Join(err, cntErr)
}

item.numItems -= cntErr.count
if item.numItems != 0 {
numItems -= cntErr.count
if numItems != 0 {
continue
}

@@ -543,7 +561,7 @@ type singleShardBatcher struct {
}

func (sb *singleShardBatcher) consume(ctx context.Context, data any) error {
return sb.batcher.consumeAndWait(ctx, data)
return sb.batcher.consumeBatch(ctx, data)
}

func (sb *singleShardBatcher) currentMetadataCardinality() int {
Expand Down Expand Up @@ -613,7 +631,7 @@ func (sb *multiShardBatcher) consume(ctx context.Context, data any) error {
sb.lock.Unlock()
}

return b.(*shard).consumeAndWait(ctx, data)
return b.(*shard).consumeBatch(ctx, data)
}

func (sb *multiShardBatcher) currentMetadataCardinality() int {
131 changes: 123 additions & 8 deletions collector/processor/concurrentbatchprocessor/batch_processor_test.go
@@ -403,6 +403,14 @@ func TestBatchProcessorSpansDeliveredEnforceBatchSize(t *testing.T) {
}

func TestBatchProcessorTracesSentBySize(t *testing.T) {
for _, early := range []bool{true, false} {
t.Run(fmt.Sprint("early=", early), func(t *testing.T) {
testBatchProcessorTracesSentBySize(t, early)
})
}
}

func testBatchProcessorTracesSentBySize(t *testing.T, early bool) {
tel := setupTestTelemetry()
sizer := &ptrace.ProtoMarshaler{}
sink := new(consumertest.TracesSink)
@@ -411,6 +419,7 @@ func TestBatchProcessorTracesSentBySize(t *testing.T) {
sendBatchSize := 20
cfg.SendBatchSize = uint32(sendBatchSize)
cfg.Timeout = 500 * time.Millisecond
cfg.EarlyReturn = early
creationSet := tel.NewSettings()
creationSet.MetricsLevel = configtelemetry.LevelDetailed
batcher, err := newBatchTracesProcessor(creationSet, sink, cfg)
@@ -527,6 +536,14 @@ func TestBatchProcessorTracesSentBySize(t *testing.T) {
}

func TestBatchProcessorTracesSentByMaxSize(t *testing.T) {
for _, early := range []bool{true, false} {
t.Run(fmt.Sprint("early=", early), func(t *testing.T) {
testBatchProcessorTracesSentByMaxSize(t, early)
})
}
}

func testBatchProcessorTracesSentByMaxSize(t *testing.T, early bool) {
tel := setupTestTelemetry()
sizer := &ptrace.ProtoMarshaler{}
sink := new(consumertest.TracesSink)
@@ -537,6 +554,7 @@ func TestBatchProcessorTracesSentByMaxSize(t *testing.T) {
cfg.SendBatchSize = uint32(sendBatchSize)
cfg.SendBatchMaxSize = uint32(sendBatchMaxSize)
cfg.Timeout = 500 * time.Millisecond
cfg.EarlyReturn = early
creationSet := tel.NewSettings()
creationSet.MetricsLevel = configtelemetry.LevelDetailed
batcher, err := newBatchTracesProcessor(creationSet, sink, cfg)
@@ -554,11 +572,17 @@ func TestBatchProcessorTracesSentByMaxSize(t *testing.T) {
sendTraces(bg, t, batcher, &wg, testdata.GenerateTraces(spansPerRequest))

wg.Wait()
require.NoError(t, batcher.Shutdown(bg))

// We expect at least one timeout period because (items % sendBatchMaxSize) != 0.
elapsed := time.Since(start)
require.GreaterOrEqual(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
if early {
// No timeout because early return.
require.Less(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
} else {
// We expect at least one timeout period because (items % sendBatchMaxSize) != 0.
require.GreaterOrEqual(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
}

require.NoError(t, batcher.Shutdown(bg))

// The max batch size is not a divisor of the total number of spans
expectedBatchesNum := int(math.Ceil(float64(totalSpans) / float64(sendBatchMaxSize)))
@@ -672,12 +696,21 @@ func TestBatchProcessorTracesSentByMaxSize(t *testing.T) {
}

func TestBatchProcessorSentByTimeout(t *testing.T) {
for _, early := range []bool{true, false} {
t.Run(fmt.Sprint("early=", early), func(t *testing.T) {
testBatchProcessorSentByTimeout(t, early)
})
}
}

func testBatchProcessorSentByTimeout(t *testing.T, early bool) {
bg := context.Background()
sink := new(consumertest.TracesSink)
cfg := createDefaultConfig().(*Config)
sendBatchSize := 100
cfg.SendBatchSize = uint32(sendBatchSize)
cfg.Timeout = 100 * time.Millisecond
cfg.EarlyReturn = early

requestCount := 5
spansPerRequest := 10
@@ -698,8 +731,13 @@ func TestBatchProcessorSentByTimeout(t *testing.T) {

wg.Wait()
elapsed := time.Since(start)
// We expect no timeout periods because (items % sendBatchMaxSize) == 0.
require.LessOrEqual(t, cfg.Timeout.Nanoseconds(), elapsed.Nanoseconds())
if early {
// We expect no timeout period because early return.
require.LessOrEqual(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
} else {
// We expect timeout periods because (items < sendBatchMaxSize).
require.Greater(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
}
require.NoError(t, batcher.Shutdown(context.Background()))

expectedBatchesNum := 1
@@ -946,10 +984,19 @@ func TestBatchMetrics_UnevenBatchMaxSize(t *testing.T) {
}

func TestBatchMetricsProcessor_Timeout(t *testing.T) {
for _, early := range []bool{true, false} {
t.Run(fmt.Sprint("early=", early), func(t *testing.T) {
testBatchMetricsProcessor_Timeout(t, early)
})
}
}

func testBatchMetricsProcessor_Timeout(t *testing.T, early bool) {
bg := context.Background()
cfg := Config{
Timeout: 100 * time.Millisecond,
SendBatchSize: 101,
EarlyReturn: early,
}
requestCount := 5
metricsPerRequest := 10
@@ -970,8 +1017,13 @@ func TestBatchMetricsProcessor_Timeout(t *testing.T) {

wg.Wait()
elapsed := time.Since(start)
// We expect at no timeout periods because (items % sendBatchMaxSize) == 0.
require.LessOrEqual(t, cfg.Timeout.Nanoseconds(), elapsed.Nanoseconds())
if early {
// We expect no timeout periods because early return.
require.LessOrEqual(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
} else {
// We expect a timeout period because (items < sendBatchSize).
require.Greater(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
}
require.NoError(t, batcher.Shutdown(context.Background()))

expectedBatchesNum := 1
@@ -1312,9 +1364,18 @@ func TestBatchLogProcessor_BatchSize(t *testing.T) {
}

func TestBatchLogsProcessor_Timeout(t *testing.T) {
for _, early := range []bool{true, false} {
t.Run(fmt.Sprint("early=", early), func(t *testing.T) {
testBatchLogsProcessor_Timeout(t, early)
})
}
}

func testBatchLogsProcessor_Timeout(t *testing.T, early bool) {
cfg := Config{
Timeout: 3 * time.Second,
SendBatchSize: 100,
EarlyReturn: early,
}
bg := context.Background()
requestCount := 5
@@ -1338,7 +1399,11 @@ func TestBatchLogsProcessor_Timeout(t *testing.T) {
wg.Wait()
elapsed := time.Since(start)
// We expect no timeout periods because (items % sendBatchMaxSize) == 0.
require.LessOrEqual(t, cfg.Timeout.Nanoseconds(), elapsed.Nanoseconds())
if early {
require.LessOrEqual(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
} else {
require.Greater(t, elapsed.Nanoseconds(), cfg.Timeout.Nanoseconds())
}
require.NoError(t, batcher.Shutdown(bg))

expectedBatchesNum := 1
@@ -1901,3 +1966,53 @@ func TestBatchProcessorEarlyReturn(t *testing.T) {
}
}
}

type blockingTracesSink struct {
consumertest.TracesSink
cancel context.CancelFunc
}

func (bts *blockingTracesSink) ConsumeTraces(ctx context.Context, td ptrace.Traces) error {
bts.cancel()
return bts.TracesSink.ConsumeTraces(ctx, td)
}

func TestBatchProcessorContextCanceledNoEarlyReturn(t *testing.T) {
bg := context.Background()
ctx, cancel := context.WithCancel(bg)
sink := &blockingTracesSink{
cancel: cancel,
}
cfg := createDefaultConfig().(*Config)
cfg.SendBatchSize = 10
cfg.SendBatchMaxSize = 10
cfg.EarlyReturn = false
creationSet := processortest.NewNopSettings()
batcher, err := newBatchTracesProcessor(creationSet, sink, cfg)
require.NoError(t, err)
require.NoError(t, batcher.Start(bg, componenttest.NewNopHost()))

spansPerRequest := 1000
var wg sync.WaitGroup
td := testdata.GenerateTraces(spansPerRequest)
// The request will be canceled by the first batch.
wg.Add(1)
go func() {
defer wg.Done()
err := batcher.ConsumeTraces(ctx, td)
assert.ErrorIs(t, err, context.Canceled)
}()

wg.Wait()

// Accepted data continues sending despite the canceled input
// context, so we will receive all the spans that were sent.
// However the client will give up early and stop reading from
// the response channel, therefore this tests whether we avoid
// blocking on the abandoned response channel.
assert.Eventually(t, func() bool {
return spansPerRequest == sink.SpanCount()
}, 10*time.Second, 100*time.Millisecond)

require.NoError(t, batcher.Shutdown(context.Background()))
}
