From 3ff59958307b6c8f33857865425b8d5b273d8d8e Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Sat, 26 Oct 2024 07:50:36 -0400 Subject: [PATCH 1/3] Feature:(issue_1927) Add return context for before handler --- command.go | 4 +++- command_test.go | 28 ++++++++++++++-------------- funcs.go | 2 +- godoc-current.txt | 2 +- testdata/godoc-v3.x.txt | 2 +- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/command.go b/command.go index 9fa8f9dcb0..91f238f2a5 100644 --- a/command.go +++ b/command.go @@ -562,9 +562,11 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } if cmd.Before != nil && !cmd.Root().shellCompletion { - if err := cmd.Before(ctx, cmd); err != nil { + if bctx, err := cmd.Before(ctx, cmd); err != nil { deferErr = cmd.handleExitCoder(ctx, err) return deferErr + } else if bctx != nil { + ctx = bctx } } diff --git a/command_test.go b/command_test.go index b8ba1eb658..160b9b16d4 100644 --- a/command_test.go +++ b/command_test.go @@ -268,8 +268,8 @@ func TestParseAndRunShortOpts(t *testing.T) { func TestCommand_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) { cmd := &Command{ Name: "bar", - Before: func(context.Context, *Command) error { - return fmt.Errorf("before error") + Before: func(context.Context, *Command) (context.Context, error) { + return nil, fmt.Errorf("before error") }, After: func(context.Context, *Command) error { return fmt.Errorf("after error") @@ -289,9 +289,9 @@ func TestCommand_Run_BeforeSavesMetadata(t *testing.T) { cmd := &Command{ Name: "bar", - Before: func(_ context.Context, cmd *Command) error { + Before: func(_ context.Context, cmd *Command) (context.Context, error) { cmd.Metadata["msg"] = "hello world" - return nil + return nil, nil }, Action: func(_ context.Context, cmd *Command) error { msg, ok := cmd.Metadata["msg"] @@ -1340,15 +1340,15 @@ func TestCommand_BeforeFunc(t *testing.T) { var err error cmd := &Command{ - Before: func(_ context.Context, cmd *Command) error { + Before: func(_ context.Context, cmd *Command) (context.Context, error) { counts.Total++ counts.Before = counts.Total s := cmd.String("opt") if s == "fail" { - return beforeError + return nil, beforeError } - return nil + return nil, nil }, Commands: []*Command{ { @@ -1411,10 +1411,10 @@ func TestCommand_BeforeAfterFuncShellCompletion(t *testing.T) { cmd := &Command{ EnableShellCompletion: true, - Before: func(context.Context, *Command) error { + Before: func(context.Context, *Command) (context.Context, error) { counts.Total++ counts.Before = counts.Total - return nil + return nil, nil }, After: func(context.Context, *Command) error { counts.Total++ @@ -1758,10 +1758,10 @@ func TestCommand_OrderOfOperations(t *testing.T) { Writer: io.Discard, } - beforeNoError := func(context.Context, *Command) error { + beforeNoError := func(context.Context, *Command) (context.Context, error) { counts.Total++ counts.Before = counts.Total - return nil + return nil, nil } cmd.Before = beforeNoError @@ -1838,10 +1838,10 @@ func TestCommand_OrderOfOperations(t *testing.T) { t.Run("before with error", func(t *testing.T) { cmd, counts := buildCmdCounts() - cmd.Before = func(context.Context, *Command) error { + cmd.Before = func(context.Context, *Command) (context.Context, error) { counts.Total++ counts.Before = counts.Total - return errors.New("hay Before") + return nil, errors.New("hay Before") } r := require.New(t) @@ -2213,7 +2213,7 @@ func TestCommand_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) { }, }, Name: "bar", - Before: func(context.Context, *Command) error { return fmt.Errorf("before error") }, + Before: func(context.Context, *Command) (context.Context, error) { return nil, fmt.Errorf("before error") }, After: func(context.Context, *Command) error { return fmt.Errorf("after error") }, }, }, diff --git a/funcs.go b/funcs.go index 9a7361d81c..e0d1e19c2b 100644 --- a/funcs.go +++ b/funcs.go @@ -8,7 +8,7 @@ type ShellCompleteFunc func(context.Context, *Command) // BeforeFunc is an action that executes prior to any subcommands being run once // the context is ready. If a non-nil error is returned, no subcommands are // run. -type BeforeFunc func(context.Context, *Command) error +type BeforeFunc func(context.Context, *Command) (context.Context, error) // AfterFunc is an action that executes after any subcommands are run and have // finished. The AfterFunc is run even if Action() panics. diff --git a/godoc-current.txt b/godoc-current.txt index 48d6ff7f0c..4d6a171142 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -247,7 +247,7 @@ func (a *ArgumentBase[T, C, VC]) Parse(s []string) ([]string, error) func (a *ArgumentBase[T, C, VC]) Usage() string -type BeforeFunc func(context.Context, *Command) error +type BeforeFunc func(context.Context, *Command) (context.Context, error) BeforeFunc is an action that executes prior to any subcommands being run once the context is ready. If a non-nil error is returned, no subcommands are run. diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index 48d6ff7f0c..4d6a171142 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -247,7 +247,7 @@ func (a *ArgumentBase[T, C, VC]) Parse(s []string) ([]string, error) func (a *ArgumentBase[T, C, VC]) Usage() string -type BeforeFunc func(context.Context, *Command) error +type BeforeFunc func(context.Context, *Command) (context.Context, error) BeforeFunc is an action that executes prior to any subcommands being run once the context is ready. If a non-nil error is returned, no subcommands are run. From 9aaed62ac4c580a5eaafe3e838c85a61085a5bde Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Sat, 26 Oct 2024 08:02:33 -0400 Subject: [PATCH 2/3] Add test --- command_test.go | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/command_test.go b/command_test.go index 160b9b16d4..5946264abe 100644 --- a/command_test.go +++ b/command_test.go @@ -289,16 +289,17 @@ func TestCommand_Run_BeforeSavesMetadata(t *testing.T) { cmd := &Command{ Name: "bar", - Before: func(_ context.Context, cmd *Command) (context.Context, error) { + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { cmd.Metadata["msg"] = "hello world" return nil, nil }, - Action: func(_ context.Context, cmd *Command) error { + Action: func(ctx context.Context, cmd *Command) error { msg, ok := cmd.Metadata["msg"] if !ok { return errors.New("msg not found") } receivedMsgFromAction = msg.(string) + return nil }, After: func(_ context.Context, cmd *Command) error { @@ -316,6 +317,40 @@ func TestCommand_Run_BeforeSavesMetadata(t *testing.T) { require.Equal(t, "hello world", receivedMsgFromAfter) } +func TestCommand_Run_BeforeReturnNewContext(t *testing.T) { + var receivedValFromAction, receivedValFromAfter string + type key string + + bkey := key("bkey") + + cmd := &Command{ + Name: "bar", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey, "bval"), nil + }, + Action: func(ctx context.Context, cmd *Command) error { + if val := ctx.Value(bkey); val == nil { + return errors.New("bkey value not found") + } else { + receivedValFromAction = val.(string) + } + return nil + }, + After: func(ctx context.Context, cmd *Command) error { + if val := ctx.Value(bkey); val == nil { + return errors.New("bkey value not found") + } else { + receivedValFromAfter = val.(string) + } + return nil + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), []string{"foo", "bar"})) + require.Equal(t, "bval", receivedValFromAfter) + require.Equal(t, "bval", receivedValFromAction) +} + func TestCommand_OnUsageError_hasCommandContext(t *testing.T) { cmd := &Command{ Name: "bar", From e069edf8bb81c0f2533eeebae209d108fc9b14d6 Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Wed, 30 Oct 2024 08:41:16 -0400 Subject: [PATCH 3/3] Update docs --- docs/v3/examples/full-api-example.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/v3/examples/full-api-example.md b/docs/v3/examples/full-api-example.md index 6a626fdaa1..a468be857a 100644 --- a/docs/v3/examples/full-api-example.md +++ b/docs/v3/examples/full-api-example.md @@ -111,9 +111,9 @@ func main() { ShellComplete: func(ctx context.Context, cmd *cli.Command) { fmt.Fprintf(cmd.Root().Writer, "--better\n") }, - Before: func(ctx context.Context, cmd *cli.Command) error { + Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) { fmt.Fprintf(cmd.Root().Writer, "brace for impact\n") - return nil + return nil, nil }, After: func(ctx context.Context, cmd *cli.Command) error { fmt.Fprintf(cmd.Root().Writer, "did we lose anyone?\n")