Skip to content

Commit

Permalink
Merge pull request #1994 from dearchap/issue_1927
Browse files Browse the repository at this point in the history
Feature:(issue_1927) Add ability for before handler to return new context
  • Loading branch information
dearchap authored Oct 30, 2024
2 parents 083c12d + e069edf commit 7ec374f
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 21 deletions.
4 changes: 3 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,11 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
}

if cmd.Before != nil && !cmd.Root().shellCompletion {
if err := cmd.Before(ctx, cmd); err != nil {
if bctx, err := cmd.Before(ctx, cmd); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
} else if bctx != nil {
ctx = bctx
}
}

Expand Down
65 changes: 50 additions & 15 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ func TestParseAndRunShortOpts(t *testing.T) {
func TestCommand_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) {
cmd := &Command{
Name: "bar",
Before: func(context.Context, *Command) error {
return fmt.Errorf("before error")
Before: func(context.Context, *Command) (context.Context, error) {
return nil, fmt.Errorf("before error")
},
After: func(context.Context, *Command) error {
return fmt.Errorf("after error")
Expand All @@ -289,16 +289,17 @@ func TestCommand_Run_BeforeSavesMetadata(t *testing.T) {

cmd := &Command{
Name: "bar",
Before: func(_ context.Context, cmd *Command) error {
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
cmd.Metadata["msg"] = "hello world"
return nil
return nil, nil
},
Action: func(_ context.Context, cmd *Command) error {
Action: func(ctx context.Context, cmd *Command) error {
msg, ok := cmd.Metadata["msg"]
if !ok {
return errors.New("msg not found")
}
receivedMsgFromAction = msg.(string)

return nil
},
After: func(_ context.Context, cmd *Command) error {
Expand All @@ -316,6 +317,40 @@ func TestCommand_Run_BeforeSavesMetadata(t *testing.T) {
require.Equal(t, "hello world", receivedMsgFromAfter)
}

func TestCommand_Run_BeforeReturnNewContext(t *testing.T) {
var receivedValFromAction, receivedValFromAfter string
type key string

bkey := key("bkey")

cmd := &Command{
Name: "bar",
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
return context.WithValue(ctx, bkey, "bval"), nil
},
Action: func(ctx context.Context, cmd *Command) error {
if val := ctx.Value(bkey); val == nil {
return errors.New("bkey value not found")
} else {
receivedValFromAction = val.(string)
}
return nil
},
After: func(ctx context.Context, cmd *Command) error {
if val := ctx.Value(bkey); val == nil {
return errors.New("bkey value not found")
} else {
receivedValFromAfter = val.(string)
}
return nil
},
}

require.NoError(t, cmd.Run(buildTestContext(t), []string{"foo", "bar"}))
require.Equal(t, "bval", receivedValFromAfter)
require.Equal(t, "bval", receivedValFromAction)
}

func TestCommand_OnUsageError_hasCommandContext(t *testing.T) {
cmd := &Command{
Name: "bar",
Expand Down Expand Up @@ -1340,15 +1375,15 @@ func TestCommand_BeforeFunc(t *testing.T) {
var err error

cmd := &Command{
Before: func(_ context.Context, cmd *Command) error {
Before: func(_ context.Context, cmd *Command) (context.Context, error) {
counts.Total++
counts.Before = counts.Total
s := cmd.String("opt")
if s == "fail" {
return beforeError
return nil, beforeError
}

return nil
return nil, nil
},
Commands: []*Command{
{
Expand Down Expand Up @@ -1411,10 +1446,10 @@ func TestCommand_BeforeAfterFuncShellCompletion(t *testing.T) {

cmd := &Command{
EnableShellCompletion: true,
Before: func(context.Context, *Command) error {
Before: func(context.Context, *Command) (context.Context, error) {
counts.Total++
counts.Before = counts.Total
return nil
return nil, nil
},
After: func(context.Context, *Command) error {
counts.Total++
Expand Down Expand Up @@ -1758,10 +1793,10 @@ func TestCommand_OrderOfOperations(t *testing.T) {
Writer: io.Discard,
}

beforeNoError := func(context.Context, *Command) error {
beforeNoError := func(context.Context, *Command) (context.Context, error) {
counts.Total++
counts.Before = counts.Total
return nil
return nil, nil
}

cmd.Before = beforeNoError
Expand Down Expand Up @@ -1838,10 +1873,10 @@ func TestCommand_OrderOfOperations(t *testing.T) {

t.Run("before with error", func(t *testing.T) {
cmd, counts := buildCmdCounts()
cmd.Before = func(context.Context, *Command) error {
cmd.Before = func(context.Context, *Command) (context.Context, error) {
counts.Total++
counts.Before = counts.Total
return errors.New("hay Before")
return nil, errors.New("hay Before")
}

r := require.New(t)
Expand Down Expand Up @@ -2213,7 +2248,7 @@ func TestCommand_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) {
},
},
Name: "bar",
Before: func(context.Context, *Command) error { return fmt.Errorf("before error") },
Before: func(context.Context, *Command) (context.Context, error) { return nil, fmt.Errorf("before error") },
After: func(context.Context, *Command) error { return fmt.Errorf("after error") },
},
},
Expand Down
4 changes: 2 additions & 2 deletions docs/v3/examples/full-api-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ func main() {
ShellComplete: func(ctx context.Context, cmd *cli.Command) {
fmt.Fprintf(cmd.Root().Writer, "--better\n")
},
Before: func(ctx context.Context, cmd *cli.Command) error {
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
fmt.Fprintf(cmd.Root().Writer, "brace for impact\n")
return nil
return nil, nil
},
After: func(ctx context.Context, cmd *cli.Command) error {
fmt.Fprintf(cmd.Root().Writer, "did we lose anyone?\n")
Expand Down
2 changes: 1 addition & 1 deletion funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type ShellCompleteFunc func(context.Context, *Command)
// BeforeFunc is an action that executes prior to any subcommands being run once
// the context is ready. If a non-nil error is returned, no subcommands are
// run.
type BeforeFunc func(context.Context, *Command) error
type BeforeFunc func(context.Context, *Command) (context.Context, error)

// AfterFunc is an action that executes after any subcommands are run and have
// finished. The AfterFunc is run even if Action() panics.
Expand Down
2 changes: 1 addition & 1 deletion godoc-current.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (a *ArgumentBase[T, C, VC]) Parse(s []string) ([]string, error)

func (a *ArgumentBase[T, C, VC]) Usage() string

type BeforeFunc func(context.Context, *Command) error
type BeforeFunc func(context.Context, *Command) (context.Context, error)
BeforeFunc is an action that executes prior to any subcommands being run
once the context is ready. If a non-nil error is returned, no subcommands
are run.
Expand Down
2 changes: 1 addition & 1 deletion testdata/godoc-v3.x.txt
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (a *ArgumentBase[T, C, VC]) Parse(s []string) ([]string, error)

func (a *ArgumentBase[T, C, VC]) Usage() string

type BeforeFunc func(context.Context, *Command) error
type BeforeFunc func(context.Context, *Command) (context.Context, error)
BeforeFunc is an action that executes prior to any subcommands being run
once the context is ready. If a non-nil error is returned, no subcommands
are run.
Expand Down

0 comments on commit 7ec374f

Please sign in to comment.