Skip to content

Commit c81c46a

Browse files
authored
Add 'one required flag' group (#1952)
1 parent dcb405a commit c81c46a

File tree

4 files changed

+228
-12
lines changed

4 files changed

+228
-12
lines changed

completions_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2830,6 +2830,104 @@ func TestCompletionForGroupedFlags(t *testing.T) {
28302830
}
28312831
}
28322832

2833+
func TestCompletionForOneRequiredGroupFlags(t *testing.T) {
2834+
getCmd := func() *Command {
2835+
rootCmd := &Command{
2836+
Use: "root",
2837+
Run: emptyRun,
2838+
}
2839+
childCmd := &Command{
2840+
Use: "child",
2841+
ValidArgsFunction: func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
2842+
return []string{"subArg"}, ShellCompDirectiveNoFileComp
2843+
},
2844+
Run: emptyRun,
2845+
}
2846+
rootCmd.AddCommand(childCmd)
2847+
2848+
rootCmd.PersistentFlags().Int("ingroup1", -1, "ingroup1")
2849+
rootCmd.PersistentFlags().String("ingroup2", "", "ingroup2")
2850+
2851+
childCmd.Flags().Bool("ingroup3", false, "ingroup3")
2852+
childCmd.Flags().Bool("nogroup", false, "nogroup")
2853+
2854+
// Add flags to a group
2855+
childCmd.MarkFlagsOneRequired("ingroup1", "ingroup2", "ingroup3")
2856+
2857+
return rootCmd
2858+
}
2859+
2860+
// Each test case uses a unique command from the function above.
2861+
testcases := []struct {
2862+
desc string
2863+
args []string
2864+
expectedOutput string
2865+
}{
2866+
{
2867+
desc: "flags in group suggested without - prefix",
2868+
args: []string{"child", ""},
2869+
expectedOutput: strings.Join([]string{
2870+
"--ingroup1",
2871+
"--ingroup2",
2872+
"--ingroup3",
2873+
"subArg",
2874+
":4",
2875+
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
2876+
},
2877+
{
2878+
desc: "flags in group suggested with - prefix",
2879+
args: []string{"child", "-"},
2880+
expectedOutput: strings.Join([]string{
2881+
"--ingroup1",
2882+
"--ingroup2",
2883+
"--ingroup3",
2884+
":4",
2885+
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
2886+
},
2887+
{
2888+
desc: "when any flag in group present, other flags in group not suggested without - prefix",
2889+
args: []string{"child", "--ingroup2", "value", ""},
2890+
expectedOutput: strings.Join([]string{
2891+
"subArg",
2892+
":4",
2893+
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
2894+
},
2895+
{
2896+
desc: "when all flags in group present, flags not suggested without - prefix",
2897+
args: []string{"child", "--ingroup1", "8", "--ingroup2", "value2", "--ingroup3", ""},
2898+
expectedOutput: strings.Join([]string{
2899+
"subArg",
2900+
":4",
2901+
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
2902+
},
2903+
{
2904+
desc: "group ignored if some flags not applicable",
2905+
args: []string{"--ingroup2", "value", ""},
2906+
expectedOutput: strings.Join([]string{
2907+
"child",
2908+
"completion",
2909+
"help",
2910+
":4",
2911+
"Completion ended with directive: ShellCompDirectiveNoFileComp", ""}, "\n"),
2912+
},
2913+
}
2914+
2915+
for _, tc := range testcases {
2916+
t.Run(tc.desc, func(t *testing.T) {
2917+
c := getCmd()
2918+
args := []string{ShellCompNoDescRequestCmd}
2919+
args = append(args, tc.args...)
2920+
output, err := executeCommand(c, args...)
2921+
switch {
2922+
case err == nil && output != tc.expectedOutput:
2923+
t.Errorf("expected: %q, got: %q", tc.expectedOutput, output)
2924+
case err != nil:
2925+
t.Errorf("Unexpected error %q", err)
2926+
}
2927+
})
2928+
}
2929+
}
2930+
28332931
func TestCompletionForMutuallyExclusiveFlags(t *testing.T) {
28342932
getCmd := func() *Command {
28352933
rootCmd := &Command{

flag_groups.go

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
const (
2626
requiredAsGroup = "cobra_annotation_required_if_others_set"
27+
oneRequired = "cobra_annotation_one_required"
2728
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
2829
)
2930

@@ -43,6 +44,22 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
4344
}
4445
}
4546

47+
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
48+
// if the command is invoked without at least one flag from the given set of flags.
49+
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
50+
c.mergePersistentFlags()
51+
for _, v := range flagNames {
52+
f := c.Flags().Lookup(v)
53+
if f == nil {
54+
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
55+
}
56+
if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
57+
// Only errs if the flag isn't found.
58+
panic(err)
59+
}
60+
}
61+
}
62+
4663
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
4764
// if the command is invoked with more than one flag from the given set of flags.
4865
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
@@ -59,7 +76,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
5976
}
6077
}
6178

62-
// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
79+
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
6380
// first error encountered.
6481
func (c *Command) ValidateFlagGroups() error {
6582
if c.DisableFlagParsing {
@@ -71,15 +88,20 @@ func (c *Command) ValidateFlagGroups() error {
7188
// groupStatus format is the list of flags as a unique ID,
7289
// then a map of each flag name and whether it is set or not.
7390
groupStatus := map[string]map[string]bool{}
91+
oneRequiredGroupStatus := map[string]map[string]bool{}
7492
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
7593
flags.VisitAll(func(pflag *flag.Flag) {
7694
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
95+
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
7796
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
7897
})
7998

8099
if err := validateRequiredFlagGroups(groupStatus); err != nil {
81100
return err
82101
}
102+
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103+
return err
104+
}
83105
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
84106
return err
85107
}
@@ -142,6 +164,27 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error {
142164
return nil
143165
}
144166

167+
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168+
keys := sortedKeys(data)
169+
for _, flagList := range keys {
170+
flagnameAndStatus := data[flagList]
171+
var set []string
172+
for flagname, isSet := range flagnameAndStatus {
173+
if isSet {
174+
set = append(set, flagname)
175+
}
176+
}
177+
if len(set) >= 1 {
178+
continue
179+
}
180+
181+
// Sort values, so they can be tested/scripted against consistently.
182+
sort.Strings(set)
183+
return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184+
}
185+
return nil
186+
}
187+
145188
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
146189
keys := sortedKeys(data)
147190
for _, flagList := range keys {
@@ -176,6 +219,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
176219

177220
// enforceFlagGroupsForCompletion will do the following:
178221
// - when a flag in a group is present, other flags in the group will be marked required
222+
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
179223
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
180224
// This allows the standard completion logic to behave appropriately for flag groups
181225
func (c *Command) enforceFlagGroupsForCompletion() {
@@ -185,9 +229,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
185229

186230
flags := c.Flags()
187231
groupStatus := map[string]map[string]bool{}
232+
oneRequiredGroupStatus := map[string]map[string]bool{}
188233
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
189234
c.Flags().VisitAll(func(pflag *flag.Flag) {
190235
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
236+
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
191237
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
192238
})
193239

@@ -204,6 +250,26 @@ func (c *Command) enforceFlagGroupsForCompletion() {
204250
}
205251
}
206252

253+
// If none of the flags of a one-required group are present, we make all the flags
254+
// of that group required so that the shell completion suggests them automatically
255+
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
256+
set := 0
257+
258+
for _, isSet := range flagnameAndStatus {
259+
if isSet {
260+
set++
261+
}
262+
}
263+
264+
// None of the flags of the group are set, mark all flags in the group
265+
// as required
266+
if set == 0 {
267+
for _, fName := range strings.Split(flagList, " ") {
268+
_ = c.MarkFlagRequired(fName)
269+
}
270+
}
271+
}
272+
207273
// If a flag that is mutually exclusive to others is present, we hide the other
208274
// flags of that group so the shell completion does not suggest them
209275
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {

flag_groups_test.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ func TestValidateFlagGroups(t *testing.T) {
4343

4444
// Each test case uses a unique command from the function above.
4545
testcases := []struct {
46-
desc string
47-
flagGroupsRequired []string
48-
flagGroupsExclusive []string
49-
subCmdFlagGroupsRequired []string
50-
subCmdFlagGroupsExclusive []string
51-
args []string
52-
expectErr string
46+
desc string
47+
flagGroupsRequired []string
48+
flagGroupsOneRequired []string
49+
flagGroupsExclusive []string
50+
subCmdFlagGroupsRequired []string
51+
subCmdFlagGroupsOneRequired []string
52+
subCmdFlagGroupsExclusive []string
53+
args []string
54+
expectErr string
5355
}{
5456
{
5557
desc: "No flags no problem",
@@ -62,6 +64,11 @@ func TestValidateFlagGroups(t *testing.T) {
6264
flagGroupsRequired: []string{"a b c"},
6365
args: []string{"--a=foo"},
6466
expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]",
67+
}, {
68+
desc: "One-required flag group not satisfied",
69+
flagGroupsOneRequired: []string{"a b"},
70+
args: []string{"--c=foo"},
71+
expectErr: "at least one of the flags in the group [a b] is required",
6572
}, {
6673
desc: "Exclusive flag group not satisfied",
6774
flagGroupsExclusive: []string{"a b c"},
@@ -72,6 +79,11 @@ func TestValidateFlagGroups(t *testing.T) {
7279
flagGroupsRequired: []string{"a b c", "a d"},
7380
args: []string{"--c=foo", "--d=foo"},
7481
expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`,
82+
}, {
83+
desc: "Multiple one-required flag group not satisfied returns first error",
84+
flagGroupsOneRequired: []string{"a b", "d e"},
85+
args: []string{"--c=foo", "--f=foo"},
86+
expectErr: `at least one of the flags in the group [a b] is required`,
7587
}, {
7688
desc: "Multiple exclusive flag group not satisfied returns first error",
7789
flagGroupsExclusive: []string{"a b c", "a d"},
@@ -82,32 +94,57 @@ func TestValidateFlagGroups(t *testing.T) {
8294
flagGroupsRequired: []string{"a d", "a b", "a c"},
8395
args: []string{"--a=foo"},
8496
expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`,
97+
}, {
98+
desc: "Validation of one-required groups occurs on groups in sorted order",
99+
flagGroupsOneRequired: []string{"d e", "a b", "f g"},
100+
args: []string{"--c=foo"},
101+
expectErr: `at least one of the flags in the group [a b] is required`,
85102
}, {
86103
desc: "Validation of exclusive groups occurs on groups in sorted order",
87104
flagGroupsExclusive: []string{"a d", "a b", "a c"},
88105
args: []string{"--a=foo", "--b=foo", "--c=foo"},
89106
expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`,
90107
}, {
91-
desc: "Persistent flags utilize both features and can fail required groups",
108+
desc: "Persistent flags utilize required and exclusive groups and can fail required groups",
92109
flagGroupsRequired: []string{"a e", "e f"},
93110
flagGroupsExclusive: []string{"f g"},
94111
args: []string{"--a=foo", "--f=foo", "--g=foo"},
95112
expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`,
96113
}, {
97-
desc: "Persistent flags utilize both features and can fail mutually exclusive groups",
114+
desc: "Persistent flags utilize one-required and exclusive groups and can fail one-required groups",
115+
flagGroupsOneRequired: []string{"a b", "e f"},
116+
flagGroupsExclusive: []string{"e f"},
117+
args: []string{"--e=foo"},
118+
expectErr: `at least one of the flags in the group [a b] is required`,
119+
}, {
120+
desc: "Persistent flags utilize required and exclusive groups and can fail mutually exclusive groups",
98121
flagGroupsRequired: []string{"a e", "e f"},
99122
flagGroupsExclusive: []string{"f g"},
100123
args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"},
101124
expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`,
102125
}, {
103-
desc: "Persistent flags utilize both features and can pass",
126+
desc: "Persistent flags utilize required and exclusive groups and can pass",
104127
flagGroupsRequired: []string{"a e", "e f"},
105128
flagGroupsExclusive: []string{"f g"},
106129
args: []string{"--a=foo", "--e=foo", "--f=foo"},
130+
}, {
131+
desc: "Persistent flags utilize one-required and exclusive groups and can pass",
132+
flagGroupsOneRequired: []string{"a e", "e f"},
133+
flagGroupsExclusive: []string{"f g"},
134+
args: []string{"--a=foo", "--e=foo", "--f=foo"},
107135
}, {
108136
desc: "Subcmds can use required groups using inherited flags",
109137
subCmdFlagGroupsRequired: []string{"e subonly"},
110138
args: []string{"subcmd", "--e=foo", "--subonly=foo"},
139+
}, {
140+
desc: "Subcmds can use one-required groups using inherited flags",
141+
subCmdFlagGroupsOneRequired: []string{"e subonly"},
142+
args: []string{"subcmd", "--e=foo", "--subonly=foo"},
143+
}, {
144+
desc: "Subcmds can use one-required groups using inherited flags and fail one-required groups",
145+
subCmdFlagGroupsOneRequired: []string{"e subonly"},
146+
args: []string{"subcmd"},
147+
expectErr: "at least one of the flags in the group [e subonly] is required",
111148
}, {
112149
desc: "Subcmds can use exclusive groups using inherited flags",
113150
subCmdFlagGroupsExclusive: []string{"e subonly"},
@@ -130,12 +167,18 @@ func TestValidateFlagGroups(t *testing.T) {
130167
for _, flagGroup := range tc.flagGroupsRequired {
131168
c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
132169
}
170+
for _, flagGroup := range tc.flagGroupsOneRequired {
171+
c.MarkFlagsOneRequired(strings.Split(flagGroup, " ")...)
172+
}
133173
for _, flagGroup := range tc.flagGroupsExclusive {
134174
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
135175
}
136176
for _, flagGroup := range tc.subCmdFlagGroupsRequired {
137177
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...)
138178
}
179+
for _, flagGroup := range tc.subCmdFlagGroupsOneRequired {
180+
sub.MarkFlagsOneRequired(strings.Split(flagGroup, " ")...)
181+
}
139182
for _, flagGroup := range tc.subCmdFlagGroupsExclusive {
140183
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
141184
}

site/content/user_guide.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,16 @@ rootCmd.Flags().BoolVar(&ofYaml, "yaml", false, "Output in YAML")
349349
rootCmd.MarkFlagsMutuallyExclusive("json", "yaml")
350350
```
351351

352-
In both of these cases:
352+
If you want to require at least one flag from a group to be present, you can use `MarkFlagsOneRequired`.
353+
This can be combined with `MarkFlagsMutuallyExclusive` to enforce exactly one flag from a given group:
354+
```go
355+
rootCmd.Flags().BoolVar(&ofJson, "json", false, "Output in JSON")
356+
rootCmd.Flags().BoolVar(&ofYaml, "yaml", false, "Output in YAML")
357+
rootCmd.MarkFlagsOneRequired("json", "yaml")
358+
rootCmd.MarkFlagsMutuallyExclusive("json", "yaml")
359+
```
360+
361+
In these cases:
353362
- both local and persistent flags can be used
354363
- **NOTE:** the group is only enforced on commands where every flag is defined
355364
- a flag may appear in multiple groups

0 commit comments

Comments
 (0)