Skip to content

Commit

Permalink
worker builder jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
hmoragrega committed Feb 7, 2021
1 parent 082fd5e commit 0a320a4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
44 changes: 37 additions & 7 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ func (f JobFunc) Do(ctx context.Context) error {
return f(ctx)
}

// JobBuilder is a job that needs to be build on the
// initialization for each worker.
type JobBuilder interface {
// New generates a new job for each workers.
//
// Its useful for jobs that need to share data between
// calls
New() Job
}

// JobBuilderFunc is a type of job that shares
// requires an initialization on each worker.
type JobBuilderFunc func() Job

// New builds the new job.
func (f JobBuilderFunc) New() Job {
return f()
}

// Middleware is a function that wraps the job and can
// be used to extend the functionality of the pool.
type Middleware interface {
Expand Down Expand Up @@ -170,9 +189,9 @@ type Pool struct {
max int

// job and its workers.
job Job
mws []Middleware
workers []*worker
jobBuilder JobBuilder
mws []Middleware
workers []*worker

// Current pool state.
started bool
Expand All @@ -197,8 +216,19 @@ type Pool struct {
mx sync.RWMutex
}

// StartBuilder launches the workers and keeps them running until the pool is closed.
func (p *Pool) StartBuilder(jobBuilder JobBuilder) error {
return p.start(jobBuilder)
}

// Start launches the workers and keeps them running until the pool is closed.
func (p *Pool) Start(job Job) error {
return p.start(JobBuilderFunc(func() Job {
return job
}))
}

func (p *Pool) start(jobBuilder JobBuilder) error {
p.mx.Lock()
defer p.mx.Unlock()

Expand All @@ -213,8 +243,8 @@ func (p *Pool) Start(job Job) error {
initial = 1
}

p.jobBuilder = jobBuilder
p.started = true
p.job = Wrap(job, p.mws...)
p.running = make(chan int)
p.done = make(chan struct{})
p.workers = make([]*worker, initial)
Expand Down Expand Up @@ -372,9 +402,9 @@ func (p *Pool) close() error {
return nil
}

// CloseWIthTimeout closes the pool waiting
// CloseWithTimeout closes the pool waiting
// for a certain amount of time.
func (p *Pool) CloseWIthTimeout(timeout time.Duration) error {
func (p *Pool) CloseWithTimeout(timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

Expand All @@ -392,7 +422,7 @@ func (p *Pool) newWorker() *worker {
defer func() {
p.running <- -1
}()
w.work(ctx, p.job)
w.work(ctx, Wrap(p.jobBuilder.New(), p.mws...))
}()

return w
Expand Down
12 changes: 6 additions & 6 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestPool_Start(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
p := Must(NewWithConfig(tc.config))
t.Cleanup(func() {
if err := p.CloseWIthTimeout(time.Second); err != nil {
if err := p.CloseWithTimeout(time.Second); err != nil {
t.Fatal("cannot stop pool", err)
}
})
Expand All @@ -128,7 +128,7 @@ func TestPool_Start(t *testing.T) {
func TestPool_StartErrors(t *testing.T) {
t.Run("pool closed", func(t *testing.T) {
var p Pool
if err := p.CloseWIthTimeout(time.Second); err != nil {
if err := p.CloseWithTimeout(time.Second); err != nil {
t.Fatalf("unexpected error closing an uninitialized pool; got %+v", err)
}

Expand Down Expand Up @@ -176,7 +176,7 @@ func TestPool_More(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
_ = tc.pool.CloseWIthTimeout(time.Second)
_ = tc.pool.CloseWithTimeout(time.Second)
})
if err := tc.pool.Start(dummyJob); err != nil {
t.Fatalf("unexpected error starting pool: %+v", err)
Expand Down Expand Up @@ -243,7 +243,7 @@ func TestPool_Less(t *testing.T) {
t.Run("error minimum number of workers reached", func(t *testing.T) {
var p Pool
t.Cleanup(func() {
if err := p.CloseWIthTimeout(time.Second); err != nil {
if err := p.CloseWithTimeout(time.Second); err != nil {
t.Fatal("cannot stop pool", err)
}
})
Expand Down Expand Up @@ -300,7 +300,7 @@ func TestPool_Close(t *testing.T) {
t.Fatalf("unexpected error starting pool: %+v", err)
}

got := p.CloseWIthTimeout(time.Second)
got := p.CloseWithTimeout(time.Second)
if !errors.Is(got, nil) {
t.Fatalf("unexpected error closing pool: %+v, want nil", got)
}
Expand All @@ -321,7 +321,7 @@ func TestPool_Close(t *testing.T) {
}

<-running
got := p.CloseWIthTimeout(25 * time.Millisecond)
got := p.CloseWithTimeout(25 * time.Millisecond)
if !errors.Is(got, context.DeadlineExceeded) {
t.Fatalf("unexpected error closing pool: %+v, want %+v", got, context.DeadlineExceeded)
}
Expand Down

0 comments on commit 0a320a4

Please sign in to comment.