
Commit 4974182

Refactor worker pool for better scalability and readability
- Split the `executeTask` function into `executeTaskWithTimeout` and `executeTaskWithoutTimeout` for better readability.
- Move the worker pool scaling logic into a separate goroutine that runs periodically, improving scalability and simplifying the `dispatch` function.
- Add `retryCount` and `adjustInterval` fields to the `goPool` struct to support task retry and an adjustable worker-scaling interval.
- Update the tests and README to reflect these changes.

Signed-off-by: Daniel Hu <[email protected]>
1 parent 5a7dd13 commit 4974182
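To see how the pieces fit together from a caller's perspective, here is a minimal in-package sketch (the import path is not shown on this page, so the snippet assumes it lives alongside gopool.go; the pool size and task body are hypothetical):

```go
// Sketch only: exercises the API surface this commit touches.
// NewGoPool, WithRetryCount, AddTask, and Release all appear in the
// diffs below; the sizes and the task itself are made up.
func retrySketch() {
	pool := NewGoPool(100, WithRetryCount(2))
	defer pool.Release()

	pool.AddTask(func() (interface{}, error) {
		// With retryCount = 2, a failing task is re-run up to 2 more
		// times before its error is reported.
		return "ok", nil
	})
}
```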

File tree

5 files changed: 107 additions, 54 deletions

README.md

Lines changed: 0 additions & 2 deletions

```diff
@@ -22,8 +22,6 @@ GoPool is a high-performance, feature-rich, and easy-to-use worker pool library
 
 - **Task Retry**: GoPool provides a retry mechanism for failed tasks.
 
-- **Task Progress Tracking**: GoPool provides task progress tracking.
-
 - **Concurrency Control**: GoPool can control the number of concurrent tasks to prevent system overload.
 
 - **Lock Customization**: GoPool supports different types of locks. You can use the built-in `sync.Mutex` or a custom lock such as `spinlock.SpinLock`.
```

gopool.go

Lines changed: 45 additions & 10 deletions

```diff
@@ -16,11 +16,13 @@ type goPool struct {
     minWorkers     int
     workerStack    []int
     taskQueue      chan task
+    retryCount     int
     lock           sync.Locker
     cond           *sync.Cond
     timeout        time.Duration
     resultCallback func(interface{})
     errorCallback  func(error)
+    adjustInterval time.Duration
 }
 
 // NewGoPool creates a new pool of workers.
@@ -31,8 +33,10 @@ func NewGoPool(maxWorkers int, opts ...Option) *goPool {
         workers:        make([]*worker, maxWorkers),
         workerStack:    make([]int, maxWorkers),
         taskQueue:      make(chan task, 1e6),
+        retryCount:     0,
         lock:           new(sync.Mutex),
         timeout:        0,
+        adjustInterval: 1 * time.Second,
     }
     for _, opt := range opts {
         opt(pool)
@@ -46,6 +50,7 @@ func NewGoPool(maxWorkers int, opts ...Option) *goPool {
         pool.workerStack[i] = i
         worker.start(pool, i)
     }
+    go pool.adjustWorkers()
     go pool.dispatch()
     return pool
 }
@@ -59,7 +64,7 @@ func (p *goPool) AddTask(t task) {
 func (p *goPool) Release() {
     close(p.taskQueue)
     p.cond.L.Lock()
-    for len(p.workerStack) != p.maxWorkers {
+    for len(p.workerStack) != p.minWorkers {
         p.cond.Wait()
     }
     p.cond.L.Unlock()
@@ -85,6 +90,31 @@ func (p *goPool) pushWorker(workerIndex int) {
     p.cond.Signal()
 }
 
+func (p *goPool) adjustWorkers() {
+    ticker := time.NewTicker(p.adjustInterval)
+    defer ticker.Stop()
+
+    for range ticker.C {
+        p.cond.L.Lock()
+        if len(p.taskQueue) > (p.maxWorkers-p.minWorkers)/2+p.minWorkers && len(p.workerStack) < p.maxWorkers {
+            // Double the number of workers until it reaches the maximum
+            newWorkers := min(len(p.workerStack)*2, p.maxWorkers) - len(p.workerStack)
+            for i := 0; i < newWorkers; i++ {
+                worker := newWorker()
+                p.workers = append(p.workers, worker)
+                p.workerStack = append(p.workerStack, len(p.workers)-1)
+                worker.start(p, len(p.workers)-1)
+            }
+        } else if len(p.taskQueue) < p.minWorkers && len(p.workerStack) > p.minWorkers {
+            // Halve the number of workers until it reaches the minimum
+            removeWorkers := max((len(p.workerStack)-p.minWorkers)/2, p.minWorkers)
+            p.workers = p.workers[:len(p.workers)-removeWorkers]
+            p.workerStack = p.workerStack[:len(p.workerStack)-removeWorkers]
+        }
+        p.cond.L.Unlock()
+    }
+}
+
 func (p *goPool) dispatch() {
     for t := range p.taskQueue {
         p.cond.L.Lock()
@@ -94,14 +124,19 @@ func (p *goPool) dispatch() {
         p.cond.L.Unlock()
         workerIndex := p.popWorker()
         p.workers[workerIndex].taskQueue <- t
-        if len(p.taskQueue) > (p.maxWorkers-p.minWorkers)/2+p.minWorkers && len(p.workerStack) < p.maxWorkers {
-            worker := newWorker()
-            p.workers = append(p.workers, worker)
-            p.workerStack = append(p.workerStack, len(p.workers)-1)
-            worker.start(p, len(p.workers)-1)
-        } else if len(p.taskQueue) < p.minWorkers && len(p.workerStack) > p.minWorkers {
-            p.workers = p.workers[:len(p.workers)-1]
-            p.workerStack = p.workerStack[:len(p.workerStack)-1]
-        }
     }
 }
+
+func min(a, b int) int {
+    if a < b {
+        return a
+    }
+    return b
+}
+
+func max(a, b int) int {
+    if a > b {
+        return a
+    }
+    return b
+}
```
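A worked example of the thresholds `adjustWorkers` computes may help; the pool bounds below are hypothetical and serve only to illustrate the arithmetic in the diff above:

```go
package main

import "fmt"

// min mirrors the helper added in this commit (Go 1.21's builtin
// would also work).
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func main() {
	maxWorkers, minWorkers := 10, 2 // hypothetical bounds

	// The pool grows when the queued-task backlog exceeds the midpoint
	// between minWorkers and maxWorkers: (10-2)/2 + 2 = 6 tasks.
	fmt.Println((maxWorkers-minWorkers)/2 + minWorkers) // 6

	// Growth doubles the idle-worker count, capped at maxWorkers:
	// with 4 idle workers, min(4*2, 10) - 4 = 4 new workers start.
	idle := 4
	fmt.Println(min(idle*2, maxWorkers) - idle) // 4
}
```

Note that `dispatch` no longer does this bookkeeping on every task; the ticker goroutine rechecks the queue depth once per `adjustInterval` (1 second by default).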

gopool_test.go

Lines changed: 2 additions & 2 deletions

```diff
@@ -34,7 +34,7 @@ func TestGoPoolWithSpinLock(t *testing.T) {
 func BenchmarkGoPoolWithMutex(b *testing.B) {
     var wg sync.WaitGroup
     var taskNum = int(1e6)
-    pool := NewGoPool(5e4, WithLock(new(sync.Mutex)))
+    pool := NewGoPool(1e4, WithLock(new(sync.Mutex)))
 
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
@@ -55,7 +55,7 @@ func BenchmarkGoPoolWithMutex(b *testing.B) {
 func BenchmarkGoPoolWithSpinLock(b *testing.B) {
     var wg sync.WaitGroup
     var taskNum = int(1e6)
-    pool := NewGoPool(5e4, WithLock(new(spinlock.SpinLock)))
+    pool := NewGoPool(1e4, WithLock(new(spinlock.SpinLock)))
 
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
```

option.go

Lines changed: 7 additions & 0 deletions

```diff
@@ -43,3 +43,10 @@ func WithErrorCallback(callback func(error)) Option {
         p.errorCallback = callback
     }
 }
+
+// WithRetryCount sets the retry count for the pool.
+func WithRetryCount(retryCount int) Option {
+    return func(p *goPool) {
+        p.retryCount = retryCount
+    }
+}
```
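`adjustInterval` gets a default in `NewGoPool` but no public option in this commit. If one were exposed later, it would presumably follow the same functional-option shape; `WithAdjustInterval` below is a hypothetical sketch, not part of the diff:

```go
// Hypothetical option in the style of WithRetryCount; NOT part of this
// commit, which only hard-codes adjustInterval to 1 second in NewGoPool.
func WithAdjustInterval(interval time.Duration) Option {
	return func(p *goPool) {
		p.adjustInterval = interval
	}
}
```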

worker.go

Lines changed: 53 additions & 40 deletions

```diff
@@ -20,49 +20,62 @@ func (w *worker) start(pool *goPool, workerIndex int) {
     go func() {
         for t := range w.taskQueue {
             if t != nil {
-                var result interface{}
-                var err error
+                result, err := w.executeTask(t, pool)
+                w.handleResult(result, err, pool)
+            }
+            pool.pushWorker(workerIndex)
+        }
+    }()
+}
 
-                if pool.timeout > 0 {
-                    // Create a context with timeout
-                    ctx, cancel := context.WithTimeout(context.Background(), pool.timeout)
-                    defer cancel()
+func (w *worker) executeTask(t task, pool *goPool) (result interface{}, err error) {
+    for i := 0; i <= pool.retryCount; i++ {
+        if pool.timeout > 0 {
+            result, err = w.executeTaskWithTimeout(t, pool)
+        } else {
+            result, err = w.executeTaskWithoutTimeout(t, pool)
+        }
+        if err == nil || i == pool.retryCount {
+            return result, err
+        }
+    }
+    return
+}
 
-                    // Create a channel to receive the result of the task
-                    done := make(chan struct{})
+func (w *worker) executeTaskWithTimeout(t task, pool *goPool) (result interface{}, err error) {
+    // Create a context with timeout
+    ctx, cancel := context.WithTimeout(context.Background(), pool.timeout)
+    defer cancel()
 
-                    // Run the task in a separate goroutine
-                    go func() {
-                        result, err = t()
-                        close(done)
-                    }()
+    // Create a channel to receive the result of the task
+    done := make(chan struct{})
 
-                    // Wait for the task to finish or for the context to timeout
-                    select {
-                    case <-done:
-                        // The task finished successfully
-                        if err != nil && pool.errorCallback != nil {
-                            pool.errorCallback(err)
-                        } else if pool.resultCallback != nil {
-                            pool.resultCallback(result)
-                        }
-                    case <-ctx.Done():
-                        // The context timed out, the task took too long
-                        if pool.errorCallback != nil {
-                            pool.errorCallback(fmt.Errorf("Task timed out"))
-                        }
-                    }
-                } else {
-                    // If timeout is not set or is zero, just run the task
-                    result, err = t()
-                    if err != nil && pool.errorCallback != nil {
-                        pool.errorCallback(err)
-                    } else if pool.resultCallback != nil {
-                        pool.resultCallback(result)
-                    }
-                }
-            }
-            pool.pushWorker(workerIndex)
-        }
+    // Run the task in a separate goroutine
+    go func() {
+        result, err = t()
+        close(done)
     }()
+
+    // Wait for the task to finish or for the context to timeout
+    select {
+    case <-done:
+        // The task finished successfully
+        return result, err
+    case <-ctx.Done():
+        // The context timed out, the task took too long
+        return nil, fmt.Errorf("Task timed out")
+    }
+}
+
+func (w *worker) executeTaskWithoutTimeout(t task, pool *goPool) (result interface{}, err error) {
+    // If timeout is not set or is zero, just run the task
+    return t()
+}
+
+func (w *worker) handleResult(result interface{}, err error, pool *goPool) {
+    if err != nil && pool.errorCallback != nil {
+        pool.errorCallback(err)
+    } else if pool.resultCallback != nil {
+        pool.resultCallback(result)
+    }
 }
```
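To pin down the retry semantics introduced by `executeTask`: with `retryCount = 2`, a task body runs at most three times, and `handleResult` sees only the final outcome. A hedged in-package sketch (the flaky task and counts are hypothetical; assumes `fmt` is imported):

```go
func retrySemanticsSketch() {
	attempts := 0
	pool := NewGoPool(10,
		WithRetryCount(2), // 1 initial attempt + up to 2 retries
		WithErrorCallback(func(err error) {
			// Fires only if all 3 attempts return a non-nil error.
			fmt.Println("final error:", err)
		}),
	)

	pool.AddTask(func() (interface{}, error) {
		attempts++
		if attempts < 3 {
			return nil, fmt.Errorf("transient failure #%d", attempts)
		}
		return "succeeded on attempt 3", nil // third run wins
	})

	pool.Release()
}
```

One caveat visible in the diff: `executeTaskWithTimeout` applies `pool.timeout` per attempt, so with retries the worst-case wall time is roughly `(retryCount+1) * timeout`.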
