53 changes: 12 additions & 41 deletions pkg/distributor/distributor.go
@@ -454,29 +454,6 @@ type streamTracker struct {
 	failed atomic.Int32
 }
 
-// TODO taken from Cortex, see if we can refactor out an usable interface.
-type pushTracker struct {
-	streamsPending atomic.Int32
-	streamsFailed  atomic.Int32
-	done           chan struct{}
-	err            chan error
-}
-
-// doneWithResult records the result of a stream push.
-// If err is nil, the stream push is considered successful.
-// If err is not nil, the stream push is considered failed.
-func (p *pushTracker) doneWithResult(err error) {
-	if err == nil {
-		if p.streamsPending.Dec() == 0 {
-			p.done <- struct{}{}
-		}
-	} else {
-		if p.streamsFailed.Inc() == 1 {
-			p.err <- err
-		}
-	}
-}
-
 func (d *Distributor) waitSimulatedLatency(ctx context.Context, tenantID string, start time.Time) {
 	latency := d.validator.SimulatedPushLatency(tenantID)
 	if latency > 0 {
@@ -754,10 +731,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 	const maxExpectedReplicationSet = 5 // typical replication factor 3 plus one for inactive plus one for luck
 	var descs [maxExpectedReplicationSet]ring.InstanceDesc
 
-	tracker := pushTracker{
-		done: make(chan struct{}, 1), // buffer avoids blocking if caller terminates - sendSamples() only sends once on each
-		err:  make(chan error, 1),
-	}
+	tracker := newBasicPushTracker()
 	streamsToWrite := 0
 	if d.cfg.IngesterEnabled {
 		streamsToWrite += len(streams)
@@ -766,15 +740,15 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 		streamsToWrite += len(streams)
 	}
 	// We must correctly set streamsPending before beginning any writes to ensure we don't have a race between finishing all of one path before starting the other.
-	tracker.streamsPending.Store(int32(streamsToWrite))
+	tracker.Add(int32(streamsToWrite))
 
 	if d.cfg.KafkaEnabled {
 		subring, err := d.partitionRing.PartitionRing().ShuffleShard(tenantID, d.validator.IngestionPartitionsTenantShardSize(tenantID))
 		if err != nil {
 			return nil, err
 		}
 		// We don't need to create a new context like the ingester writes, because we don't return unless all writes have succeeded.
-		d.sendStreamsToKafka(ctx, streams, tenantID, &tracker, subring)
+		d.sendStreamsToKafka(ctx, streams, tenantID, tracker, subring)
 	}
 
 	if d.cfg.IngesterEnabled {
@@ -823,7 +797,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 			case d.ingesterTasks <- pushIngesterTask{
 				ingester:      ingester,
 				streamTracker: samples,
-				pushTracker:   &tracker,
+				pushTracker:   tracker,
 				ctx:           localCtx,
 				cancel:        cancel,
 			}:
@@ -833,14 +807,11 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 		}
 	}
 
-	select {
-	case err := <-tracker.err:
+	if err := tracker.Wait(ctx); err != nil {
 		return nil, err
-	case <-tracker.done:
-		return &logproto.PushResponse{}, validationErr
-	case <-ctx.Done():
-		return nil, ctx.Err()
 	}
+
+	return &logproto.PushResponse{}, validationErr
 }
 
 // missingEnforcedLabels returns true if the stream is missing any of the required labels.
@@ -1135,7 +1106,7 @@ func (d *Distributor) truncateLines(vContext validationContext, stream *logproto
 
 type pushIngesterTask struct {
 	streamTracker []*streamTracker
-	pushTracker   *pushTracker
+	pushTracker   PushTracker
 	ingester      ring.InstanceDesc
 	ctx           context.Context
 	cancel        context.CancelFunc
@@ -1172,12 +1143,12 @@ func (d *Distributor) sendStreams(task pushIngesterTask) {
 			if task.streamTracker[i].failed.Inc() <= int32(task.streamTracker[i].maxFailures) {
 				continue
 			}
-			task.pushTracker.doneWithResult(err)
+			task.pushTracker.Done(err)
 		} else {
 			if task.streamTracker[i].succeeded.Inc() != int32(task.streamTracker[i].minSuccess) {
 				continue
 			}
-			task.pushTracker.doneWithResult(nil)
+			task.pushTracker.Done(nil)
 		}
 	}
 }
@@ -1209,14 +1180,14 @@ func (d *Distributor) sendStreamsErr(ctx context.Context, ingester ring.Instance
 	return err
 }
 
-func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker *pushTracker, subring *ring.PartitionRing) {
+func (d *Distributor) sendStreamsToKafka(ctx context.Context, streams []KeyedStream, tenant string, tracker PushTracker, subring *ring.PartitionRing) {
 	for _, s := range streams {
 		go func(s KeyedStream) {
 			err := d.sendStreamToKafka(ctx, s, tenant, subring)
 			if err != nil {
 				err = fmt.Errorf("failed to write stream to kafka: %w", err)
 			}
-			tracker.doneWithResult(err)
+			tracker.Done(err)
 		}(s)
 	}
 }
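
The call sites above converge on one pattern: register the expected number of pushes up front, complete each push exactly once, and block until the tracker resolves. A minimal sketch of that flow (not from this PR): sendToIngester and sendToKafka are hypothetical stand-ins for the real send paths, and the snippet assumes it compiles in the same package as the tracker added below.

// Hypothetical stand-ins for the real ingester and Kafka send paths.
func sendToIngester(ctx context.Context, stream string) error { return nil }
func sendToKafka(ctx context.Context, stream string) error    { return nil }

func push(ctx context.Context, streams []string) error {
	tracker := newBasicPushTracker()
	// Register every push before starting any writes, so one path cannot
	// drain the counter to zero while the other is still being scheduled.
	tracker.Add(int32(2 * len(streams)))
	for _, s := range streams {
		go func(s string) { tracker.Done(sendToIngester(ctx, s)) }(s)
		go func(s string) { tracker.Done(sendToKafka(ctx, s)) }(s)
	}
	// Wait returns the first push error, nil once every push has
	// succeeded, or the context error if the caller gives up first.
	return tracker.Wait(ctx)
}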
103 changes: 103 additions & 0 deletions pkg/distributor/tracker.go
@@ -0,0 +1,103 @@
package distributor

import (
"context"
"sync"
)

// PushTracker is an interface to track the status of pushes and wait on
// their completion.
type PushTracker interface {
// Add increments the number of pushes. It must not be called after the
// last call to [Done] has completed.
Add(int32)

	// Done decrements the number of pushes. A non-nil err records the
	// push as failed.
Done(err error)

	// Wait blocks until all pushes are done or a push fails, whichever
	// happens first. If ctx is canceled first, it returns ctx's error.
	Wait(ctx context.Context) error
}

type basicPushTracker struct {
mtx sync.Mutex // protects the fields below.
n int32 // the number of pushes.
firstErr error // the first reported error from a push.
doneCh chan struct{} // closed when all pushes are done.
errCh chan struct{} // closed when an error is reported.
done bool // fast path, equivalent to select { case <-t.doneCh: default: }
}

// newBasicPushTracker returns a new, initialized [basicPushTracker].
func newBasicPushTracker() *basicPushTracker {
return &basicPushTracker{
doneCh: make(chan struct{}),
errCh: make(chan struct{}),
}
}

// Add implements the [PushTracker] interface.
func (t *basicPushTracker) Add(n int32) {
t.mtx.Lock()
defer t.mtx.Unlock()
if t.done {
panic("Add called after last call to Done")
}
t.n += n
if t.n < 0 {
// We panic on negative counters just like [sync.WaitGroup].
panic("Negative counter")
}
}

// Done implements the [PushTracker] interface.
func (t *basicPushTracker) Done(err error) {
t.mtx.Lock()
defer t.mtx.Unlock()
if t.n <= 0 {
// We panic here just like [sync.WaitGroup].
panic("Done called more times than Add")
}
if err != nil && t.firstErr == nil {
		// errCh can never be closed twice, as t.firstErr transitions from
		// nil to non-nil at most once.
t.firstErr = err
close(t.errCh)
}
t.n--
if t.n == 0 {
close(t.doneCh)
t.done = true
}
}

// Wait implements the [PushTracker] interface.
func (t *basicPushTracker) Wait(ctx context.Context) error {
t.mtx.Lock()
	// We need the mutex here: t.n can still change while doneCh has not
	// been closed, and t.firstErr can still change while neither doneCh
	// nor errCh has been closed.
if t.firstErr != nil || t.n == 0 {
// We need to store the firstErr before releasing the mutex for the
// same reason.
res := t.firstErr
t.mtx.Unlock()
return res
}
t.mtx.Unlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.doneCh:
		// Must return t.firstErr, as doneCh is also closed when the last
		// push failed. We don't need the mutex here: t.firstErr is never
		// modified after doneCh is closed.
return t.firstErr
case <-t.errCh:
// We don't need the mutex here either as t.firstErr is never modified
// after errCh is closed.
return t.firstErr
}
}
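
A usage sketch (not part of the PR) of the semantics Wait guarantees: the first reported error unblocks waiters immediately, even while pushes are still outstanding, whereas an error-free run resolves only when the counter returns to zero. Assuming the tracker above is in scope, with the usual context, errors, fmt, and time imports:

	tracker := newBasicPushTracker()
	tracker.Add(3)
	go func() { tracker.Done(nil) }()
	go func() { tracker.Done(errors.New("boom")) }()
	// One push is still outstanding, yet Wait already returns "boom":
	// errCh is closed as soon as the first error is recorded.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println(tracker.Wait(ctx))
	// The contract is still one Done per Add, so the last push must
	// report its result even after Wait has returned.
	tracker.Done(nil)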
139 changes: 139 additions & 0 deletions pkg/distributor/tracker_test.go
@@ -0,0 +1,139 @@
package distributor

import (
"context"
"errors"
"math/rand"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestBasicPushTracker(t *testing.T) {
t.Run("a new tracker that has never been incremented should never block", func(t *testing.T) {
tracker := newBasicPushTracker()
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.NoError(t, tracker.Wait(ctx))
})

t.Run("a canceled context should return a context canceled error", func(t *testing.T) {
tracker := newBasicPushTracker()
tracker.Add(1)
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
t.Cleanup(cancel)
require.EqualError(t, tracker.Wait(ctx), "context deadline exceeded")
})

t.Run("a done tracker with no errors should return nil", func(t *testing.T) {
tracker := newBasicPushTracker()
tracker.Add(1)
tracker.Done(nil)
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.NoError(t, tracker.Wait(ctx))
})

t.Run("a done tracker with an error should return the error", func(t *testing.T) {
tracker := newBasicPushTracker()
tracker.Add(1)
tracker.Done(errors.New("an error occurred"))
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.EqualError(t, tracker.Wait(ctx), "an error occurred")
})

t.Run("a done tracker should return the first error that occurred", func(t *testing.T) {
tracker := newBasicPushTracker()
tracker.Add(2)
tracker.Done(errors.New("an error occurred"))
tracker.Done(errors.New("another error occurred"))
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.EqualError(t, tracker.Wait(ctx), "an error occurred")
})

	t.Run("a done tracker should return an error regardless of Done order", func(t *testing.T) {
t1 := newBasicPushTracker()
t1.Add(2)
t1.Done(nil)
t1.Done(errors.New("an error occurred"))
ctx, cancel := context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.EqualError(t, t1.Wait(ctx), "an error occurred")
// And now test the opposite sequence.
t2 := newBasicPushTracker()
t2.Add(2)
t2.Done(errors.New("an error occurred"))
t2.Done(nil)
ctx, cancel = context.WithTimeout(t.Context(), time.Second)
t.Cleanup(cancel)
require.EqualError(t, t2.Wait(ctx), "an error occurred")
})

t.Run("more Done than Add should panic", func(t *testing.T) {
// Should panic if Done is called before Add.
require.PanicsWithValue(t, "Done called more times than Add", func() {
tracker := newBasicPushTracker()
tracker.Done(nil)
})
// Should panic if Done is called more times than Add.
require.PanicsWithValue(t, "Done called more times than Add", func() {
tracker := newBasicPushTracker()
tracker.Add(1)
tracker.Done(nil)
tracker.Done(nil)
})
})

t.Run("Add after Done should panic", func(t *testing.T) {
require.PanicsWithValue(t, "Add called after last call to Done", func() {
tracker := newBasicPushTracker()
tracker.Add(1)
tracker.Done(nil)
tracker.Add(1)
})
})

t.Run("Negative counter should panic", func(t *testing.T) {
require.PanicsWithValue(t, "Negative counter", func() {
tracker := newBasicPushTracker()
tracker.Add(-1)
})
})
}

// Run with go test -fuzz=FuzzBasicPushTracker.
func FuzzBasicPushTracker(f *testing.F) {
f.Add(uint16(100))
f.Fuzz(func(t *testing.T, n uint16) {
wg := sync.WaitGroup{}
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
t.Cleanup(cancel)
tracker := newBasicPushTracker()
tracker.Add(int32(n))
// Create a random number of waiters.
for i := 0; i < rand.Intn(100); i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Sleep a random time up to 100ms.
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
require.NoError(t, tracker.Wait(ctx))
}()
}
		// Done must be called exactly n times, so this loop cannot be random.
for i := 0; i < int(n); i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Sleep a random time up to 100ms too.
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
tracker.Done(nil)
}()
}
wg.Wait()
})
}
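
To exercise the new tests locally, standard Go tooling is enough (package path as in this PR):

go test ./pkg/distributor/ -run TestBasicPushTracker -v
go test ./pkg/distributor/ -fuzz=FuzzBasicPushTracker -fuzztime=30s

The fuzz target treats the uint16 seed as the push count and races a random number of waiters against the Done calls, so a missed wake-up or a counter race should surface as a failed Wait within the five-second context.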