Skip to content

Commit d6eaf9a

Browse files
committed
Fix:(issue_1834) Add check for persistent required flags
1 parent 2b97d2e commit d6eaf9a

File tree

2 files changed

+104
-21
lines changed

2 files changed

+104
-21
lines changed

command.go

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -522,18 +522,26 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
522522

523523
if cmd.Action == nil {
524524
cmd.Action = helpCommandAction
525-
} else if len(cmd.Arguments) > 0 {
526-
rargs := cmd.Args().Slice()
527-
tracef("calling argparse with %[1]v", rargs)
528-
for _, arg := range cmd.Arguments {
529-
var err error
530-
rargs, err = arg.Parse(rargs)
531-
if err != nil {
532-
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
533-
return err
525+
} else {
526+
if err := cmd.checkPersistentRequiredFlags(); err != nil {
527+
cmd.isInError = true
528+
_ = ShowSubcommandHelp(cmd)
529+
return err
530+
}
531+
532+
if len(cmd.Arguments) > 0 {
533+
rargs := cmd.Args().Slice()
534+
tracef("calling argparse with %[1]v", rargs)
535+
for _, arg := range cmd.Arguments {
536+
var err error
537+
rargs, err = arg.Parse(rargs)
538+
if err != nil {
539+
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
540+
return err
541+
}
534542
}
543+
cmd.parsedArgs = &stringSliceArgs{v: rargs}
535544
}
536-
cmd.parsedArgs = &stringSliceArgs{v: rargs}
537545
}
538546

539547
if err := cmd.Action(ctx, cmd); err != nil {
@@ -840,26 +848,59 @@ func (cmd *Command) lookupFlagSet(name string) *flag.FlagSet {
840848
return nil
841849
}
842850

851+
func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
852+
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
853+
flagPresent := false
854+
flagName := ""
855+
856+
for _, key := range f.Names() {
857+
flagName = key
858+
859+
if cmd.IsSet(strings.TrimSpace(key)) {
860+
flagPresent = true
861+
}
862+
}
863+
864+
if !flagPresent && flagName != "" {
865+
return false, flagName
866+
}
867+
}
868+
return true, ""
869+
}
870+
843871
func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
844872
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
845873

846874
missingFlags := []string{}
847875

848876
for _, f := range cmd.Flags {
849-
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
850-
flagPresent := false
851-
flagName := ""
877+
if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
878+
if ok, name := cmd.checkRequiredFlag(f); !ok {
879+
missingFlags = append(missingFlags, name)
880+
}
881+
}
882+
}
852883

853-
for _, key := range f.Names() {
854-
flagName = key
884+
if len(missingFlags) != 0 {
885+
tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)
855886

856-
if cmd.IsSet(strings.TrimSpace(key)) {
857-
flagPresent = true
858-
}
859-
}
887+
return &errRequiredFlags{missingFlags: missingFlags}
888+
}
889+
890+
tracef("all required flags set (cmd=%[1]q)", cmd.Name)
891+
892+
return nil
893+
}
894+
895+
func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
896+
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)
897+
898+
missingFlags := []string{}
860899

861-
if !flagPresent && flagName != "" {
862-
missingFlags = append(missingFlags, flagName)
900+
for _, f := range cmd.appliedFlags {
901+
if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
902+
if ok, name := cmd.checkRequiredFlag(f); !ok {
903+
missingFlags = append(missingFlags, name)
863904
}
864905
}
865906
}

command_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,6 +2926,7 @@ func TestFlagAction(t *testing.T) {
29262926
func TestPersistentFlag(t *testing.T) {
29272927
var topInt, topPersistentInt, subCommandInt, appOverrideInt int64
29282928
var appFlag string
2929+
var appRequiredFlag string
29292930
var appOverrideCmdInt int64
29302931
var appSliceFloat64 []float64
29312932
var persistentCommandSliceInt []int64
@@ -2957,6 +2958,12 @@ func TestPersistentFlag(t *testing.T) {
29572958
Persistent: true,
29582959
Destination: &appOverrideInt,
29592960
},
2961+
&StringFlag{
2962+
Name: "persistentRequiredCommandFlag",
2963+
Persistent: true,
2964+
Required: true,
2965+
Destination: &appRequiredFlag,
2966+
},
29602967
},
29612968
Commands: []*Command{
29622969
{
@@ -3005,6 +3012,7 @@ func TestPersistentFlag(t *testing.T) {
30053012
"--persistentCommandSliceFlag", "102",
30063013
"--persistentCommandFloatSliceFlag", "102.455",
30073014
"--paof", "105",
3015+
"--persistentRequiredCommandFlag", "hellor",
30083016
"subcmd",
30093017
"--cmdPersistentFlag", "20",
30103018
"--cmdFlag", "11",
@@ -3021,6 +3029,10 @@ func TestPersistentFlag(t *testing.T) {
30213029
t.Errorf("Expected 'bar' got %s", appFlag)
30223030
}
30233031

3032+
if appRequiredFlag != "hellor" {
3033+
t.Errorf("Expected 'hellor' got %s", appRequiredFlag)
3034+
}
3035+
30243036
if topInt != 12 {
30253037
t.Errorf("Expected 12 got %d", topInt)
30263038
}
@@ -3096,6 +3108,36 @@ func TestPersistentFlagIsSet(t *testing.T) {
30963108
r.True(resultIsSet)
30973109
}
30983110

3111+
func TestRequiredPersistentFlag(t *testing.T) {
3112+
3113+
app := &Command{
3114+
Name: "root",
3115+
Flags: []Flag{
3116+
&StringFlag{
3117+
Name: "result",
3118+
Persistent: true,
3119+
Required: true,
3120+
},
3121+
},
3122+
Commands: []*Command{
3123+
{
3124+
Name: "sub",
3125+
Action: func(ctx context.Context, c *Command) error {
3126+
return nil
3127+
},
3128+
},
3129+
},
3130+
}
3131+
3132+
r := require.New(t)
3133+
3134+
err := app.Run(context.Background(), []string{"root", "sub"})
3135+
r.Error(err)
3136+
3137+
err = app.Run(context.Background(), []string{"root", "sub", "--result", "after"})
3138+
r.NoError(err)
3139+
}
3140+
30993141
func TestFlagDuplicates(t *testing.T) {
31003142
tests := []struct {
31013143
name string

0 commit comments

Comments
 (0)