Skip to content

Commit 5eb0e76

Browse files
authored
Merge pull request #1 from uber-go/master
Merge branch 'master' from original repo
2 parents c9b24f8 + b379e13 commit 5eb0e76

File tree

9 files changed

+474
-35
lines changed

9 files changed

+474
-35
lines changed

annotated_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,7 +1485,8 @@ func assertApp(
14851485
invoked *bool,
14861486
) {
14871487
t.Helper()
1488-
ctx := context.Background()
1488+
ctx, cancel := context.WithCancel(context.Background())
1489+
defer cancel()
14891490
assert.False(t, *started)
14901491
require.NoError(t, app.Start(ctx))
14911492
assert.True(t, *started)
@@ -1517,8 +1518,11 @@ func TestHookAnnotations(t *testing.T) {
15171518
t.Run("with hook on invoke", func(t *testing.T) {
15181519
t.Parallel()
15191520

1520-
var started bool
1521-
var invoked bool
1521+
var (
1522+
started bool
1523+
stopped bool
1524+
invoked bool
1525+
)
15221526
hook := fx.Annotate(
15231527
func() {
15241528
invoked = true
@@ -1527,10 +1531,14 @@ func TestHookAnnotations(t *testing.T) {
15271531
started = true
15281532
return nil
15291533
}),
1534+
fx.OnStop(func(context.Context) error {
1535+
stopped = true
1536+
return nil
1537+
}),
15301538
)
15311539
app := fxtest.New(t, fx.Invoke(hook))
15321540

1533-
assertApp(t, app, &started, nil, &invoked)
1541+
assertApp(t, app, &started, &stopped, &invoked)
15341542
})
15351543

15361544
t.Run("depend on result interface of target", func(t *testing.T) {

app.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,9 +619,13 @@ func (app *App) Start(ctx context.Context) (err error) {
619619
})
620620
}
621621

622-
func (app *App) start(ctx context.Context) error {
623-
if err := app.lifecycle.Start(ctx); err != nil {
624-
// Start failed, rolling back.
622+
// withRollback will execute an anonymous function with a given context.
623+
// if the anon func returns an error, rollback methods will be called and related events emitted
624+
func (app *App) withRollback(
625+
ctx context.Context,
626+
f func(context.Context) error,
627+
) error {
628+
if err := f(ctx); err != nil {
625629
app.log().LogEvent(&fxevent.RollingBack{StartErr: err})
626630

627631
stopErr := app.lifecycle.Stop(ctx)
@@ -633,9 +637,20 @@ func (app *App) start(ctx context.Context) error {
633637

634638
return err
635639
}
640+
636641
return nil
637642
}
638643

644+
func (app *App) start(ctx context.Context) error {
645+
return app.withRollback(ctx, func(ctx context.Context) error {
646+
if err := app.lifecycle.Start(ctx); err != nil {
647+
return err
648+
}
649+
app.receivers.Start(ctx)
650+
return nil
651+
})
652+
}
653+
639654
// Stop gracefully stops the application. It executes any registered OnStop
640655
// hooks in reverse order, so that each constructor's stop hooks are called
641656
// before its dependencies' stop hooks.
@@ -648,9 +663,14 @@ func (app *App) Stop(ctx context.Context) (err error) {
648663
app.log().LogEvent(&fxevent.Stopped{Err: err})
649664
}()
650665

666+
cb := func(ctx context.Context) error {
667+
defer app.receivers.Stop(ctx)
668+
return app.lifecycle.Stop(ctx)
669+
}
670+
651671
return withTimeout(ctx, &withTimeoutParams{
652672
hook: _onStopHook,
653-
callback: app.lifecycle.Stop,
673+
callback: cb,
654674
lifecycle: app.lifecycle,
655675
log: app.log(),
656676
})
@@ -663,10 +683,25 @@ func (app *App) Stop(ctx context.Context) (err error) {
663683
//
664684
// Alternatively, a signal can be broadcast to all done channels manually by
665685
// using the Shutdown functionality (see the Shutdowner documentation for details).
686+
//
687+
// Note: The channel Done returns will not receive a signal unless the application
688+
// as been started via Start or Run.
666689
func (app *App) Done() <-chan os.Signal {
667690
return app.receivers.Done()
668691
}
669692

693+
// Wait returns a channel of [ShutdownSignal] to block on after starting the
694+
// application and function, similar to [App.Done], but with a minor difference.
695+
// Should an ExitCode be provided as a [ShutdownOption] to
696+
// the Shutdowner Shutdown method, the exit code will be available as part
697+
// of the ShutdownSignal struct.
698+
//
699+
// Should the app receive a SIGTERM or SIGINT, the given
700+
// signal will be populated in the ShutdownSignal struct.
701+
func (app *App) Wait() <-chan ShutdownSignal {
702+
return app.receivers.Wait()
703+
}
704+
670705
// StartTimeout returns the configured startup timeout. Apps default to using
671706
// DefaultTimeout, but users can configure this behavior using the
672707
// StartTimeout option.

app_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,29 @@ func TestAppStart(t *testing.T) {
12811281
err := app.Start(context.Background()).Error()
12821282
assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning")
12831283
})
1284+
1285+
t.Run("StartTwiceWithHooksErrors", func(t *testing.T) {
1286+
t.Parallel()
1287+
1288+
ctx, cancel := context.WithCancel(context.Background())
1289+
defer cancel()
1290+
1291+
app := fxtest.New(t,
1292+
Invoke(func(lc Lifecycle) {
1293+
lc.Append(Hook{
1294+
OnStart: func(ctx context.Context) error { return nil },
1295+
OnStop: func(ctx context.Context) error { return nil },
1296+
})
1297+
}),
1298+
)
1299+
assert.NoError(t, app.Start(ctx))
1300+
err := app.Start(ctx)
1301+
if assert.Error(t, err) {
1302+
assert.ErrorContains(t, err, "attempted to start lifecycle when in state: started")
1303+
}
1304+
app.Stop(ctx)
1305+
assert.NoError(t, app.Start(ctx))
1306+
})
12841307
}
12851308

12861309
func TestAppStop(t *testing.T) {

internal/lifecycle/lifecycle.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,38 @@ type Hook struct {
123123
callerFrame fxreflect.Frame
124124
}
125125

126+
type appState int
127+
128+
const (
129+
stopped appState = iota
130+
starting
131+
incompleteStart
132+
started
133+
stopping
134+
)
135+
136+
func (as appState) String() string {
137+
switch as {
138+
case stopped:
139+
return "stopped"
140+
case starting:
141+
return "starting"
142+
case incompleteStart:
143+
return "incompleteStart"
144+
case started:
145+
return "started"
146+
case stopping:
147+
return "stopping"
148+
default:
149+
return "invalidState"
150+
}
151+
}
152+
126153
// Lifecycle coordinates application lifecycle hooks.
127154
type Lifecycle struct {
128155
clock fxclock.Clock
129156
logger fxevent.Logger
157+
state appState
130158
hooks []Hook
131159
numStarted int
132160
startRecords HookRecords
@@ -157,9 +185,23 @@ func (l *Lifecycle) Start(ctx context.Context) error {
157185
}
158186

159187
l.mu.Lock()
188+
if l.state != stopped {
189+
defer l.mu.Unlock()
190+
return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state)
191+
}
192+
l.numStarted = 0
193+
l.state = starting
194+
160195
l.startRecords = make(HookRecords, 0, len(l.hooks))
161196
l.mu.Unlock()
162197

198+
var returnState appState = incompleteStart
199+
defer func() {
200+
l.mu.Lock()
201+
l.state = returnState
202+
l.mu.Unlock()
203+
}()
204+
163205
for _, hook := range l.hooks {
164206
// if ctx has cancelled, bail out of the loop.
165207
if err := ctx.Err(); err != nil {
@@ -187,6 +229,7 @@ func (l *Lifecycle) Start(ctx context.Context) error {
187229
l.numStarted++
188230
}
189231

232+
returnState = started
190233
return nil
191234
}
192235

@@ -221,6 +264,20 @@ func (l *Lifecycle) Stop(ctx context.Context) error {
221264
return errors.New("called OnStop with nil context")
222265
}
223266

267+
l.mu.Lock()
268+
if l.state != started && l.state != incompleteStart {
269+
defer l.mu.Unlock()
270+
return fmt.Errorf("attempted to stop lifecycle when in state: %v", l.state)
271+
}
272+
l.state = stopping
273+
l.mu.Unlock()
274+
275+
defer func() {
276+
l.mu.Lock()
277+
l.state = stopped
278+
l.mu.Unlock()
279+
}()
280+
224281
l.mu.Lock()
225282
l.stopRecords = make(HookRecords, 0, l.numStarted)
226283
l.mu.Unlock()

internal/lifecycle/lifecycle_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ func TestLifecycleStart(t *testing.T) {
7171
assert.NoError(t, l.Start(context.Background()))
7272
assert.Equal(t, 2, count)
7373
})
74+
7475
t.Run("ErrHaltsChainAndRollsBack", func(t *testing.T) {
7576
t.Parallel()
7677

@@ -143,6 +144,18 @@ func TestLifecycleStart(t *testing.T) {
143144
// stop hooks.
144145
require.NoError(t, l.Stop(ctx))
145146
})
147+
148+
t.Run("StartWhileStartedErrors", func(t *testing.T) {
149+
t.Parallel()
150+
151+
l := New(testLogger(t), fxclock.System)
152+
assert.NoError(t, l.Start(context.Background()))
153+
err := l.Start(context.Background())
154+
require.Error(t, err)
155+
assert.Contains(t, err.Error(), "attempted to start lifecycle when in state: started")
156+
assert.NoError(t, l.Stop(context.Background()))
157+
assert.NoError(t, l.Start(context.Background()))
158+
})
146159
}
147160

148161
func TestLifecycleStop(t *testing.T) {
@@ -152,6 +165,7 @@ func TestLifecycleStop(t *testing.T) {
152165
t.Parallel()
153166

154167
l := New(testLogger(t), fxclock.System)
168+
l.Start(context.Background())
155169
assert.Nil(t, l.Stop(context.Background()), "no lifecycle hooks should have resulted in stop returning nil")
156170
})
157171

@@ -317,6 +331,16 @@ func TestLifecycleStop(t *testing.T) {
317331
assert.Contains(t, err.Error(), "called OnStop with nil context")
318332

319333
})
334+
335+
t.Run("StopWhileStoppedErrors", func(t *testing.T) {
336+
t.Parallel()
337+
338+
l := New(testLogger(t), fxclock.System)
339+
err := l.Stop(context.Background())
340+
require.Error(t, err)
341+
assert.Contains(t, err.Error(), "attempted to stop lifecycle when in state: stopped")
342+
})
343+
320344
}
321345

322346
func TestHookRecordsFormat(t *testing.T) {

shutdown.go

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020

2121
package fx
2222

23+
import (
24+
"context"
25+
"time"
26+
)
27+
2328
// Shutdowner provides a method that can manually trigger the shutdown of the
2429
// application by sending a signal to all open Done channels. Shutdowner works
2530
// on applications using Run as well as Start, Done, and Stop. The Shutdowner is
@@ -34,8 +39,42 @@ type ShutdownOption interface {
3439
apply(*shutdowner)
3540
}
3641

42+
type exitCodeOption int
43+
44+
func (code exitCodeOption) apply(s *shutdowner) {
45+
s.exitCode = int(code)
46+
}
47+
48+
var _ ShutdownOption = exitCodeOption(0)
49+
50+
// ExitCode is a [ShutdownOption] that may be passed to the Shutdown method of the
51+
// [Shutdowner] interface.
52+
// The given integer exit code will be broadcasted to any receiver waiting
53+
// on a [ShutdownSignal] from the [Wait] method.
54+
func ExitCode(code int) ShutdownOption {
55+
return exitCodeOption(code)
56+
}
57+
58+
type shutdownTimeoutOption time.Duration
59+
60+
func (to shutdownTimeoutOption) apply(s *shutdowner) {
61+
s.shutdownTimeout = time.Duration(to)
62+
}
63+
64+
var _ ShutdownOption = shutdownTimeoutOption(0)
65+
66+
// ShutdownTimeout is a [ShutdownOption] that allows users to specify a timeout
67+
// for a given call to Shutdown method of the [Shutdowner] interface. As the
68+
// Shutdown method will block while waiting for a signal receiver relay
69+
// goroutine to stop.
70+
func ShutdownTimeout(timeout time.Duration) ShutdownOption {
71+
return shutdownTimeoutOption(timeout)
72+
}
73+
3774
type shutdowner struct {
38-
app *App
75+
app *App
76+
exitCode int
77+
shutdownTimeout time.Duration
3978
}
4079

4180
// Shutdown broadcasts a signal to all of the application's Done channels
@@ -44,7 +83,27 @@ type shutdowner struct {
4483
// In practice this means Shutdowner.Shutdown should not be called from an
4584
// fx.Invoke, but from a fx.Lifecycle.OnStart hook.
4685
func (s *shutdowner) Shutdown(opts ...ShutdownOption) error {
47-
return s.app.receivers.Broadcast(ShutdownSignal{Signal: _sigTERM})
86+
for _, opt := range opts {
87+
opt.apply(s)
88+
}
89+
90+
ctx := context.Background()
91+
92+
if s.shutdownTimeout != time.Duration(0) {
93+
c, cancel := context.WithTimeout(
94+
context.Background(),
95+
s.shutdownTimeout,
96+
)
97+
defer cancel()
98+
ctx = c
99+
}
100+
101+
defer s.app.receivers.Stop(ctx)
102+
103+
return s.app.receivers.Broadcast(ShutdownSignal{
104+
Signal: _sigTERM,
105+
ExitCode: s.exitCode,
106+
})
48107
}
49108

50109
func (app *App) shutdowner() Shutdowner {

0 commit comments

Comments
 (0)