Showing 3 changed files with 539 additions and 109 deletions.
@@ -0,0 +1,14 @@
package pool

type Config struct {
	MaxWorkers int `yaml:"max_workers"`
	QueueDepth int `yaml:"queue_depth"`
}

// defaultConfig is used when no Config is provided (30 workers, queue depth 10000).
func defaultConfig() *Config {
	return &Config{
		MaxWorkers: 30,
		QueueDepth: 10000,
	}
}
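
The yaml struct tags suggest the new Config is meant to be populated from the operator-facing YAML configuration. As a rough sketch only (the gopkg.in/yaml.v2 dependency, the test name, and the values are assumptions, not part of this commit), a fragment with those keys would round-trip into the struct like this:

package pool

import (
	"testing"

	"gopkg.in/yaml.v2"
)

// Sketch: checks that the yaml tags on Config pick up the expected keys.
func TestConfigYAMLSketch(t *testing.T) {
	raw := []byte("max_workers: 50\nqueue_depth: 20000\n")

	cfg := &Config{}
	if err := yaml.Unmarshal(raw, cfg); err != nil {
		t.Fatal(err)
	}
	if cfg.MaxWorkers != 50 || cfg.QueueDepth != 20000 {
		t.Fatalf("unexpected config: %+v", cfg)
	}
}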
@@ -1,110 +1,196 @@
-// Forked with love from: https://github.com/prometheus/prometheus/tree/c954cd9d1d4e3530be2939d39d8633c38b70913f/util/pool
-// This package was forked to provide better protection against putting byte slices back into the pool that
-// did not originate from it.
-
 package pool

 import (
+	"context"
+	"fmt"
 	"sync"
+	"time"

 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
+	"github.com/uber-go/atomic"
 )

+const (
+	queueLengthReportDuration = 15 * time.Second
+)
+
 var (
-	metricMissOver  prometheus.Counter
-	metricMissUnder prometheus.Counter
+	metricQueryQueueLength = promauto.NewGauge(prometheus.GaugeOpts{
+		Namespace: "tempodb",
+		Name:      "work_queue_length",
+		Help:      "Current length of the work queue.",
+	})
+
+	metricQueryQueueMax = promauto.NewGauge(prometheus.GaugeOpts{
+		Namespace: "tempodb",
+		Name:      "work_queue_max",
+		Help:      "Maximum number of items in the work queue.",
+	})
 )

-func init() {
-	metricAllocOutPool := promauto.NewCounterVec(prometheus.CounterOpts{
-		Namespace: "tempo",
-		Name:      "ingester_prealloc_miss_bytes_total",
-		Help:      "The total number of alloc'ed bytes that missed the sync pools.",
-	}, []string{"direction"})
-
-	metricMissOver = metricAllocOutPool.WithLabelValues("over")
-	metricMissUnder = metricAllocOutPool.WithLabelValues("under")
-}
-
-// Pool is a linearly bucketed pool for variably sized byte slices.
-type Pool struct {
-	buckets   []sync.Pool
-	bktSize   int
-	minBucket int
-}
-
-// New returns a new Pool with size buckets for minSize to maxSize
-func New(minBucket, numBuckets, bktSize int) *Pool {
-	if minBucket < 0 {
-		panic("invalid min bucket size")
-	}
-	if bktSize < 1 {
-		panic("invalid bucket size")
-	}
-	if numBuckets < 1 {
-		panic("invalid num buckets")
-	}
-
-	return &Pool{
-		buckets:   make([]sync.Pool, numBuckets),
-		bktSize:   bktSize,
-		minBucket: minBucket,
-	}
-}
-
-// Get returns a new byte slices that fits the given size.
-func (p *Pool) Get(sz int) []byte {
-	if sz < 0 {
-		panic("requested negative size")
-	}
-
-	// Find the right bucket.
-	bkt := p.bucketFor(sz)
-
-	if bkt < 0 {
-		metricMissUnder.Add(float64(sz))
-		return make([]byte, 0, sz)
-	}
-
-	if bkt >= len(p.buckets) {
-		metricMissOver.Add(float64(sz))
-		return make([]byte, 0, sz)
-	}
-
-	b := p.buckets[bkt].Get()
-	if b == nil {
-		alignedSz := (bkt+1)*p.bktSize + p.minBucket
-		b = make([]byte, 0, alignedSz)
-	}
-	return b.([]byte)
-}
-
-// Put adds a slice to the right bucket in the pool.
-func (p *Pool) Put(s []byte) int {
-	c := cap(s)
-
-	// valid slice?
-	if (c-p.minBucket)%p.bktSize != 0 {
-		return -1
-	}
-	bkt := p.bucketFor(c) // -1 puts the slice in the pool below. it will be larger than all requested slices for this bucket
-	if bkt < 0 {
-		return -1
-	}
-	if bkt >= len(p.buckets) {
-		return -1
-	}
-
-	p.buckets[bkt].Put(s) // nolint: staticcheck
-
-	return bkt // for testing
-}
-
-func (p *Pool) bucketFor(sz int) int {
-	if sz <= p.minBucket {
-		return -1
-	}
-
-	return (sz - p.minBucket - 1) / p.bktSize
-}
+type JobFunc func(ctx context.Context, payload interface{}) (interface{}, error)
+
+type result struct {
+	data interface{}
+	err  error
+}
+
+type job struct {
+	ctx     context.Context
+	payload interface{}
+	fn      JobFunc
+
+	wg        *sync.WaitGroup
+	resultsCh chan result
+	stop      *atomic.Bool
+}
+
+type Pool struct {
+	cfg  *Config
+	size *atomic.Int32
+
+	workQueue  chan *job
+	shutdownCh chan struct{}
+}
+
+func NewPool(cfg *Config) *Pool {
+	if cfg == nil {
+		cfg = defaultConfig()
+	}
+
+	q := make(chan *job, cfg.QueueDepth)
+	p := &Pool{
+		cfg:        cfg,
+		workQueue:  q,
+		size:       atomic.NewInt32(0),
+		shutdownCh: make(chan struct{}),
+	}
+
+	for i := 0; i < cfg.MaxWorkers; i++ {
+		go p.worker(q)
+	}
+
+	p.reportQueueLength()
+
+	metricQueryQueueMax.Set(float64(cfg.QueueDepth))
+
+	return p
+}
+
+func (p *Pool) RunJobs(ctx context.Context, payloads []interface{}, fn JobFunc) ([]interface{}, []error, error) {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	totalJobs := len(payloads)
+
+	// sanity check before we even attempt to start adding jobs
+	if int(p.size.Load())+totalJobs > p.cfg.QueueDepth {
+		return nil, nil, fmt.Errorf("queue doesn't have room for %d jobs", len(payloads))
+	}
+
+	resultsCh := make(chan result, totalJobs) // way for jobs to send back results
+	stop := atomic.NewBool(false)             // way to signal to the jobs to quit
+	wg := &sync.WaitGroup{}                   // way to wait for all jobs to complete
+
+	// add each job one at a time. even though we checked length above these might still fail
+	for _, payload := range payloads {
+		wg.Add(1)
+		j := &job{
+			ctx:       ctx,
+			fn:        fn,
+			payload:   payload,
+			wg:        wg,
+			resultsCh: resultsCh,
+			stop:      stop,
+		}
+
+		select {
+		case p.workQueue <- j:
+			p.size.Inc()
+		default:
+			wg.Done()
+			stop.Store(true)
+			return nil, nil, fmt.Errorf("failed to add a job to work queue")
+		}
+	}
+
+	// wait for all jobs to finish
+	wg.Wait()
+
+	// close resultsCh
+	close(resultsCh)
+
+	// read all from results channel
+	var data []interface{}
+	var funcErrs []error
+	for result := range resultsCh {
+		if result.err != nil {
+			funcErrs = append(funcErrs, result.err)
+		} else {
+			data = append(data, result.data)
+		}
+	}
+
+	return data, funcErrs, nil
+}
+
+func (p *Pool) Shutdown() {
+	close(p.workQueue)
+	close(p.shutdownCh)
+}
+
+func (p *Pool) worker(j <-chan *job) {
+	for {
+		select {
+		case <-p.shutdownCh:
+			return
+		case j, ok := <-j:
+			if !ok {
+				return
+			}
+			runJob(j)
+			p.size.Dec()
+		}
+	}
+}
+
+func (p *Pool) reportQueueLength() {
+	ticker := time.NewTicker(queueLengthReportDuration)
+	go func() {
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				metricQueryQueueLength.Set(float64(p.size.Load()))
+			case <-p.shutdownCh:
+				return
+			}
+		}
+	}()
+}
+
+func runJob(job *job) {
+	defer job.wg.Done()
+
+	// bail in case we have been asked to stop
+	if job.ctx.Err() != nil {
+		return
+	}
+
+	// bail in case not all jobs could be enqueued
+	if job.stop.Load() {
+		return
+	}
+
+	data, err := job.fn(job.ctx, job.payload)
+	if data != nil || err != nil {
+		select {
+		case job.resultsCh <- result{
+			data: data,
+			err:  err,
+		}:
+		default: // if we hit default it means that something else already returned a good result. /shrug
+		}
+	}
+}
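
For orientation, here is a minimal sketch of how the new API might be exercised from inside the package; the payload type, worker count, and helper name are illustrative assumptions rather than code from this commit:

package pool

import (
	"context"
	"fmt"
)

// sketchRunJobs squares a batch of ints on the pool's workers and prints the results.
func sketchRunJobs() error {
	p := NewPool(&Config{
		MaxWorkers: 4,   // illustrative values, not the defaults from this commit
		QueueDepth: 100,
	})
	defer p.Shutdown()

	payloads := []interface{}{1, 2, 3, 4}

	square := func(ctx context.Context, payload interface{}) (interface{}, error) {
		n, ok := payload.(int)
		if !ok {
			return nil, fmt.Errorf("unexpected payload type %T", payload)
		}
		return n * n, nil
	}

	results, funcErrs, err := p.RunJobs(context.Background(), payloads, square)
	if err != nil {
		return err // the queue was full or a job could not be enqueued
	}
	for _, e := range funcErrs {
		fmt.Println("job error:", e)
	}
	for _, r := range results {
		fmt.Println("result:", r)
	}
	return nil
}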