Skip to content

Commit 20ef97b

Browse files
authored
Merge pull request #1975 from dearchap/dont_check_req_flags
Dont check required flags until command action is executed
2 parents 31c5c84 + 725b339 commit 20ef97b

File tree

2 files changed

+85
-6
lines changed

2 files changed

+85
-6
lines changed

command.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -554,12 +554,6 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
554554
}()
555555
}
556556

557-
if err := cmd.checkRequiredFlags(); err != nil {
558-
cmd.isInError = true
559-
_ = ShowSubcommandHelp(cmd)
560-
return err
561-
}
562-
563557
for _, grp := range cmd.MutuallyExclusiveFlags {
564558
if err := grp.check(cmd); err != nil {
565559
_ = ShowSubcommandHelp(cmd)
@@ -636,6 +630,12 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
636630
if cmd.Action == nil {
637631
cmd.Action = helpCommandAction
638632
} else {
633+
if err := cmd.checkAllRequiredFlags(); err != nil {
634+
cmd.isInError = true
635+
_ = ShowSubcommandHelp(cmd)
636+
return err
637+
}
638+
639639
if err := cmd.checkPersistentRequiredFlags(); err != nil {
640640
cmd.isInError = true
641641
_ = ShowSubcommandHelp(cmd)
@@ -993,6 +993,15 @@ func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
993993
return true, ""
994994
}
995995

996+
func (cmd *Command) checkAllRequiredFlags() requiredFlagsErr {
997+
if cmd.parent != nil {
998+
if err := cmd.parent.checkRequiredFlags(); err != nil {
999+
return err
1000+
}
1001+
}
1002+
return cmd.checkRequiredFlags()
1003+
}
1004+
9961005
func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
9971006
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
9981007

command_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2941,6 +2941,76 @@ func TestPersistentFlagIsSet(t *testing.T) {
29412941
require.True(t, resultIsSet)
29422942
}
29432943

2944+
func TestRequiredFlagDelayed(t *testing.T) {
2945+
sf := &StringFlag{
2946+
Name: "result",
2947+
Required: true,
2948+
}
2949+
2950+
expectedErr := &errRequiredFlags{
2951+
missingFlags: []string{sf.Name},
2952+
}
2953+
2954+
tests := []struct {
2955+
name string
2956+
args []string
2957+
errExpected error
2958+
}{
2959+
{
2960+
name: "leaf help",
2961+
args: []string{"root", "sub", "-h"},
2962+
errExpected: nil,
2963+
},
2964+
{
2965+
name: "leaf action",
2966+
args: []string{"root", "sub"},
2967+
errExpected: expectedErr,
2968+
},
2969+
{
2970+
name: "leaf flags set",
2971+
args: []string{"root", "sub", "--if", "10"},
2972+
errExpected: expectedErr,
2973+
},
2974+
{
2975+
name: "leaf invalid flags set",
2976+
args: []string{"root", "sub", "--xx"},
2977+
errExpected: expectedErr,
2978+
},
2979+
}
2980+
2981+
app := &Command{
2982+
Name: "root",
2983+
Flags: []Flag{
2984+
sf,
2985+
},
2986+
Commands: []*Command{
2987+
{
2988+
Name: "sub",
2989+
Flags: []Flag{
2990+
&IntFlag{
2991+
Name: "if",
2992+
Required: true,
2993+
},
2994+
},
2995+
Action: func(ctx context.Context, c *Command) error {
2996+
return nil
2997+
},
2998+
},
2999+
},
3000+
}
3001+
3002+
for _, test := range tests {
3003+
t.Run(test.name, func(t *testing.T) {
3004+
err := app.Run(context.Background(), test.args)
3005+
if test.errExpected == nil {
3006+
require.NoError(t, err)
3007+
} else {
3008+
require.ErrorAs(t, err, &test.errExpected)
3009+
}
3010+
})
3011+
}
3012+
}
3013+
29443014
func TestRequiredPersistentFlag(t *testing.T) {
29453015
app := &Command{
29463016
Name: "root",

0 commit comments

Comments
 (0)