Commit
Showing 3 changed files with 161 additions and 528 deletions.
@@ -1,196 +1,110 @@
 // Forked with love from: https://github.com/prometheus/prometheus/tree/c954cd9d1d4e3530be2939d39d8633c38b70913f/util/pool
 // This package was forked to provide better protection against putting byte slices back into the pool that
 // did not originate from it.
 
 package pool
 
 import (
-	"context"
-	"fmt"
 	"sync"
-	"time"
 
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
-	"github.com/uber-go/atomic"
 )
 
-const (
-	queueLengthReportDuration = 15 * time.Second
-)
-
 var (
-	metricQueryQueueLength = promauto.NewGauge(prometheus.GaugeOpts{
-		Namespace: "tempodb",
-		Name:      "work_queue_length",
-		Help:      "Current length of the work queue.",
-	})
-
-	metricQueryQueueMax = promauto.NewGauge(prometheus.GaugeOpts{
-		Namespace: "tempodb",
-		Name:      "work_queue_max",
-		Help:      "Maximum number of items in the work queue.",
-	})
+	metricMissOver  prometheus.Counter
+	metricMissUnder prometheus.Counter
 )
 
-type JobFunc func(ctx context.Context, payload interface{}) (interface{}, error)
-
-type result struct {
-	data interface{}
-	err  error
-}
-
-type job struct {
-	ctx     context.Context
-	payload interface{}
-	fn      JobFunc
+func init() {
+	metricAllocOutPool := promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: "tempo",
+		Name:      "ingester_prealloc_miss_bytes_total",
+		Help:      "The total number of alloc'ed bytes that missed the sync pools.",
+	}, []string{"direction"})
 
-	wg        *sync.WaitGroup
-	resultsCh chan result
-	stop      *atomic.Bool
+	metricMissOver = metricAllocOutPool.WithLabelValues("over")
+	metricMissUnder = metricAllocOutPool.WithLabelValues("under")
 }
 
+// Pool is a linearly bucketed pool for variably sized byte slices.
 type Pool struct {
-	cfg  *Config
-	size *atomic.Int32
-
-	workQueue  chan *job
-	shutdownCh chan struct{}
+	buckets   []sync.Pool
+	bktSize   int
+	minBucket int
 }
 
-func NewPool(cfg *Config) *Pool {
-	if cfg == nil {
-		cfg = defaultConfig()
+// New returns a new Pool with size buckets for minSize to maxSize
+func New(minBucket, numBuckets, bktSize int) *Pool {
+	if minBucket < 0 {
+		panic("invalid min bucket size")
 	}
 
-	q := make(chan *job, cfg.QueueDepth)
-	p := &Pool{
-		cfg:        cfg,
-		workQueue:  q,
-		size:       atomic.NewInt32(0),
-		shutdownCh: make(chan struct{}),
+	if bktSize < 1 {
+		panic("invalid bucket size")
 	}
 
-	for i := 0; i < cfg.MaxWorkers; i++ {
-		go p.worker(q)
+	if numBuckets < 1 {
+		panic("invalid num buckets")
 	}
 
-	p.reportQueueLength()
-
-	metricQueryQueueMax.Set(float64(cfg.QueueDepth))
-
-	return p
+	return &Pool{
+		buckets:   make([]sync.Pool, numBuckets),
+		bktSize:   bktSize,
+		minBucket: minBucket,
+	}
 }
 
-func (p *Pool) RunJobs(ctx context.Context, payloads []interface{}, fn JobFunc) ([]interface{}, []error, error) {
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
+// Get returns a new byte slices that fits the given size.
+func (p *Pool) Get(sz int) []byte {
+	if sz < 0 {
+		panic("requested negative size")
+	}
 
-	totalJobs := len(payloads)
+	// Find the right bucket.
+	bkt := p.bucketFor(sz)
 
-	// sanity check before we even attempt to start adding jobs
-	if int(p.size.Load())+totalJobs > p.cfg.QueueDepth {
-		return nil, nil, fmt.Errorf("queue doesn't have room for %d jobs", len(payloads))
+	if bkt < 0 {
+		metricMissUnder.Add(float64(sz))
+		return make([]byte, 0, sz)
 	}
 
-	resultsCh := make(chan result, totalJobs) // way for jobs to send back results
-	stop := atomic.NewBool(false) // way to signal to the jobs to quit
-	wg := &sync.WaitGroup{} // way to wait for all jobs to complete
-
-	// add each job one at a time. even though we checked length above these might still fail
-	for _, payload := range payloads {
-		wg.Add(1)
-		j := &job{
-			ctx:       ctx,
-			fn:        fn,
-			payload:   payload,
-			wg:        wg,
-			resultsCh: resultsCh,
-			stop:      stop,
-		}
-
-		select {
-		case p.workQueue <- j:
-			p.size.Inc()
-		default:
-			wg.Done()
-			stop.Store(true)
-			return nil, nil, fmt.Errorf("failed to add a job to work queue")
-		}
+	if bkt >= len(p.buckets) {
+		metricMissOver.Add(float64(sz))
+		return make([]byte, 0, sz)
 	}
 
-	// wait for all jobs to finish
-	wg.Wait()
-
-	// close resultsCh
-	close(resultsCh)
-
-	// read all from results channel
-	var data []interface{}
-	var funcErrs []error
-	for result := range resultsCh {
-		if result.err != nil {
-			funcErrs = append(funcErrs, result.err)
-		} else {
-			data = append(data, result.data)
-		}
+	b := p.buckets[bkt].Get()
+	if b == nil {
+		alignedSz := (bkt+1)*p.bktSize + p.minBucket
+		b = make([]byte, 0, alignedSz)
 	}
 
-	return data, funcErrs, nil
+	return b.([]byte)
 }
 
-func (p *Pool) Shutdown() {
-	close(p.workQueue)
-	close(p.shutdownCh)
-}
+// Put adds a slice to the right bucket in the pool.
+func (p *Pool) Put(s []byte) int {
+	c := cap(s)
 
-func (p *Pool) worker(j <-chan *job) {
-	for {
-		select {
-		case <-p.shutdownCh:
-			return
-		case j, ok := <-j:
-			if !ok {
-				return
-			}
-			runJob(j)
-			p.size.Dec()
-		}
+	// valid slice?
+	if (c-p.minBucket)%p.bktSize != 0 {
+		return -1
 	}
+	bkt := p.bucketFor(c) // -1 puts the slice in the pool below. it will be larger than all requested slices for this bucket
+	if bkt < 0 {
+		return -1
+	}
+	if bkt >= len(p.buckets) {
+		return -1
+	}
-}
 
-func (p *Pool) reportQueueLength() {
-	ticker := time.NewTicker(queueLengthReportDuration)
-	go func() {
-		defer ticker.Stop()
-		for {
-			select {
-			case <-ticker.C:
-				metricQueryQueueLength.Set(float64(p.size.Load()))
-			case <-p.shutdownCh:
-				return
-			}
-		}
-	}()
-}
-
-func runJob(job *job) {
-	defer job.wg.Done()
+	p.buckets[bkt].Put(s) // nolint: staticcheck
 
-	// bail in case we have been asked to stop
-	if job.ctx.Err() != nil {
-		return
-	}
+	return bkt // for testing
+}
 
-	// bail in case not all jobs could be enqueued
-	if job.stop.Load() {
-		return
+func (p *Pool) bucketFor(sz int) int {
+	if sz <= p.minBucket {
+		return -1
+	}
 
-	data, err := job.fn(job.ctx, job.payload)
-	if data != nil || err != nil {
-		select {
-		case job.resultsCh <- result{
-			data: data,
-			err:  err,
-		}:
-		default: // if we hit default it means that something else already returned a good result. /shrug
-		}
-	}
+	return (sz - p.minBucket - 1) / p.bktSize
 }
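
For orientation, here is a small usage sketch of the bucketed pool introduced by this commit. It is written as a hypothetical snippet in the same pool package so no import path needs to be assumed; the function name usageSketch and the constructor arguments (minBucket=100, 10 buckets, bktSize=100) are illustrative and not part of the commit.

package pool

import "fmt"

// usageSketch walks through the bucket math with minBucket=100,
// bktSize=100 and 10 buckets: requests of 100 bytes or less miss
// "under", requests above 100+10*100=1100 bytes miss "over", and
// everything in between gets a slice whose capacity is rounded up
// to the next bucket boundary.
func usageSketch() {
	p := New(100, 10, 100)

	// 150 bytes maps to bucket 0, so the returned slice has
	// cap (0+1)*100 + 100 = 200.
	b := p.Get(150)
	fmt.Println(cap(b)) // 200

	// Returning the slice stores it for reuse; Put reports which
	// bucket it landed in, or -1 if it was rejected.
	fmt.Println(p.Put(b))

	// A slice whose capacity is not aligned to minBucket+k*bktSize
	// did not come from this pool and is rejected.
	fmt.Println(p.Put(make([]byte, 0, 137))) // -1
}

Returning the bucket index from Put (rather than nothing) appears to be deliberate: the "// for testing" comment in the diff suggests it exists so tests can assert where a slice ended up.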