Skip to content

Commit fc3b515

Browse files
authored
Merge pull request #1785 from dearchap/validation
Feature: Add support for validation functions
2 parents 4a9488f + 51cb2ef commit fc3b515

10 files changed

+215
-114
lines changed

command_test.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,27 +3068,6 @@ func TestPersistentFlag(t *testing.T) {
30683068
}
30693069

30703070
func TestFlagDuplicates(t *testing.T) {
3071-
cmd := &Command{
3072-
Flags: []Flag{
3073-
&StringFlag{
3074-
Name: "sflag",
3075-
OnlyOnce: true,
3076-
},
3077-
&IntSliceFlag{
3078-
Name: "isflag",
3079-
},
3080-
&FloatSliceFlag{
3081-
Name: "fsflag",
3082-
OnlyOnce: true,
3083-
},
3084-
&IntFlag{
3085-
Name: "iflag",
3086-
},
3087-
},
3088-
Action: func(ctx *Context) error {
3089-
return nil
3090-
},
3091-
}
30923071

30933072
tests := []struct {
30943073
name string
@@ -3117,6 +3096,28 @@ func TestFlagDuplicates(t *testing.T) {
31173096

31183097
for _, test := range tests {
31193098
t.Run(test.name, func(t *testing.T) {
3099+
cmd := &Command{
3100+
Flags: []Flag{
3101+
&StringFlag{
3102+
Name: "sflag",
3103+
OnlyOnce: true,
3104+
},
3105+
&IntSliceFlag{
3106+
Name: "isflag",
3107+
},
3108+
&FloatSliceFlag{
3109+
Name: "fsflag",
3110+
OnlyOnce: true,
3111+
},
3112+
&IntFlag{
3113+
Name: "iflag",
3114+
},
3115+
},
3116+
Action: func(ctx *Context) error {
3117+
return nil
3118+
},
3119+
}
3120+
31203121
err := cmd.Run(buildTestContext(t), test.args)
31213122
if test.errExpected && err == nil {
31223123
t.Error("expected error")

errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,12 @@ func handleMultiError(multiErr MultiError) int {
199199
}
200200
return code
201201
}
202+
203+
type typeError[T any] struct {
204+
other any
205+
}
206+
207+
func (te *typeError[T]) Error() string {
208+
var t T
209+
return fmt.Sprintf("Expected type %T got instead %T", t, te.other)
210+
}

flag_impl.go

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,47 +13,35 @@ type Value interface {
1313
flag.Getter
1414
}
1515

16-
// simple wrapper to intercept Value operations
17-
// to check for duplicates
18-
type valueWrapper struct {
19-
value Value
20-
count int
21-
onlyOnce bool
16+
type boolFlag interface {
17+
IsBoolFlag() bool
2218
}
2319

24-
func (v *valueWrapper) String() string {
25-
if v.value == nil {
26-
return ""
27-
}
28-
return v.value.String()
20+
type fnValue struct {
21+
fn func(string) error
22+
isBool bool
23+
v Value
2924
}
3025

31-
func (v *valueWrapper) Set(s string) error {
32-
if v.count == 1 && v.onlyOnce {
33-
return fmt.Errorf("cant duplicate this flag")
26+
func (f *fnValue) Get() any { return f.v.Get() }
27+
func (f *fnValue) Set(s string) error { return f.fn(s) }
28+
func (f *fnValue) String() string {
29+
if f.v == nil {
30+
return ""
3431
}
35-
v.count++
36-
return v.value.Set(s)
37-
}
38-
39-
func (v *valueWrapper) Get() any {
40-
return v.value.Get()
41-
}
42-
43-
func (v *valueWrapper) IsBoolFlag() bool {
44-
_, ok := v.value.(*boolValue)
45-
return ok
32+
return f.v.String()
4633
}
4734

48-
func (v *valueWrapper) Serialize() string {
49-
if s, ok := v.value.(Serializer); ok {
35+
func (f *fnValue) Serialize() string {
36+
if s, ok := f.v.(Serializer); ok {
5037
return s.Serialize()
5138
}
52-
return v.value.String()
39+
return f.v.String()
5340
}
5441

55-
func (v *valueWrapper) Count() int {
56-
if s, ok := v.value.(Countable); ok {
42+
func (f *fnValue) IsBoolFlag() bool { return f.isBool }
43+
func (f *fnValue) Count() int {
44+
if s, ok := f.v.(Countable); ok {
5745
return s.Count()
5846
}
5947
return 0
@@ -105,7 +93,10 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct {
10593

10694
OnlyOnce bool // whether this flag can be duplicated on the command line
10795

96+
Validator func(T) error // custom function to validate this flag value
97+
10898
// unexported fields for internal use
99+
count int // number of times the flag has been set
109100
hasBeenSet bool // whether the flag has been set from env or file
110101
applied bool // whether the flag has been applied to a flag set already
111102
creator VC // value creator for this flag type
@@ -160,15 +151,48 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error {
160151
} else {
161152
f.value = f.creator.Create(newVal, f.Destination, f.Config)
162153
}
154+
155+
// Validate the given default or values set from external sources as well
156+
if f.Validator != nil {
157+
if v, ok := f.value.Get().(T); !ok {
158+
return &typeError[T]{
159+
other: f.value.Get(),
160+
}
161+
} else if err := f.Validator(v); err != nil {
162+
return err
163+
}
164+
}
163165
}
164166

165-
vw := &valueWrapper{
166-
value: f.value,
167-
onlyOnce: f.OnlyOnce,
167+
isBool := false
168+
if b, ok := f.value.(boolFlag); ok && b.IsBoolFlag() {
169+
isBool = true
168170
}
169171

170172
for _, name := range f.Names() {
171-
set.Var(vw, name, f.Usage)
173+
set.Var(&fnValue{
174+
fn: func(val string) error {
175+
if f.count == 1 && f.OnlyOnce {
176+
return fmt.Errorf("cant duplicate this flag")
177+
}
178+
f.count++
179+
if err := f.value.Set(val); err != nil {
180+
return err
181+
}
182+
if f.Validator != nil {
183+
if v, ok := f.value.Get().(T); !ok {
184+
return &typeError[T]{
185+
other: f.value.Get(),
186+
}
187+
} else if err := f.Validator(v); err != nil {
188+
return err
189+
}
190+
}
191+
return nil
192+
},
193+
isBool: isBool,
194+
v: f.value,
195+
}, name, f.Usage)
172196
}
173197

174198
f.applied = true

flag_int_slice.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package cli
22

3-
import "flag"
4-
53
type IntSlice = SliceBase[int64, IntegerConfig, intValue]
64
type IntSliceFlag = FlagBase[[]int64, IntegerConfig, IntSlice]
75

@@ -10,18 +8,9 @@ var NewIntSlice = NewSliceBase[int64, IntegerConfig, intValue]
108
// IntSlice looks up the value of a local IntSliceFlag, returns
119
// nil if not found
1210
func (cCtx *Context) IntSlice(name string) []int64 {
13-
if fs := cCtx.lookupFlagSet(name); fs != nil {
14-
return lookupIntSlice(name, fs)
11+
if v, ok := cCtx.Value(name).([]int64); ok {
12+
return v
1513
}
16-
return nil
17-
}
1814

19-
func lookupIntSlice(name string, set *flag.FlagSet) []int64 {
20-
f := set.Lookup(name)
21-
if f != nil {
22-
if slice, ok := f.Value.(flag.Getter).Get().([]int64); ok {
23-
return slice
24-
}
25-
}
2615
return nil
2716
}

flag_string_map.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package cli
22

3-
import "flag"
4-
53
type StringMap = MapBase[string, StringConfig, stringValue]
64
type StringMapFlag = FlagBase[map[string]string, StringConfig, StringMap]
75

@@ -10,18 +8,8 @@ var NewStringMap = NewMapBase[string, StringConfig, stringValue]
108
// StringMap looks up the value of a local StringMapFlag, returns
119
// nil if not found
1210
func (cCtx *Context) StringMap(name string) map[string]string {
13-
if fs := cCtx.lookupFlagSet(name); fs != nil {
14-
return lookupStringMap(name, fs)
15-
}
16-
return nil
17-
}
18-
19-
func lookupStringMap(name string, set *flag.FlagSet) map[string]string {
20-
f := set.Lookup(name)
21-
if f != nil {
22-
if mapping, ok := f.Value.(flag.Getter).Get().(map[string]string); ok {
23-
return mapping
24-
}
11+
if v, ok := cCtx.Value(name).(map[string]string); ok {
12+
return v
2513
}
2614
return nil
2715
}

flag_string_slice.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package cli
22

3-
import (
4-
"flag"
5-
)
6-
73
type StringSlice = SliceBase[string, StringConfig, stringValue]
84
type StringSliceFlag = FlagBase[[]string, StringConfig, StringSlice]
95

@@ -12,18 +8,8 @@ var NewStringSlice = NewSliceBase[string, StringConfig, stringValue]
128
// StringSlice looks up the value of a local StringSliceFlag, returns
139
// nil if not found
1410
func (cCtx *Context) StringSlice(name string) []string {
15-
if fs := cCtx.lookupFlagSet(name); fs != nil {
16-
return lookupStringSlice(name, fs)
17-
}
18-
return nil
19-
}
20-
21-
func lookupStringSlice(name string, set *flag.FlagSet) []string {
22-
f := set.Lookup(name)
23-
if f != nil {
24-
if slice, ok := f.Value.(flag.Getter).Get().([]string); ok {
25-
return slice
26-
}
11+
if v, ok := cCtx.Value(name).([]string); ok {
12+
return v
2713
}
2814
return nil
2915
}

flag_uint_slice.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package cli
22

3-
import (
4-
"flag"
5-
)
6-
73
type UintSlice = SliceBase[uint64, IntegerConfig, uintValue]
84
type UintSliceFlag = FlagBase[[]uint64, IntegerConfig, UintSlice]
95

@@ -12,18 +8,8 @@ var NewUintSlice = NewSliceBase[uint64, IntegerConfig, uintValue]
128
// UintSlice looks up the value of a local UintSliceFlag, returns
139
// nil if not found
1410
func (cCtx *Context) UintSlice(name string) []uint64 {
15-
if fs := cCtx.lookupFlagSet(name); fs != nil {
16-
return lookupUintSlice(name, fs)
17-
}
18-
return nil
19-
}
20-
21-
func lookupUintSlice(name string, set *flag.FlagSet) []uint64 {
22-
f := set.Lookup(name)
23-
if f != nil {
24-
if slice, ok := f.Value.(flag.Getter).Get().([]uint64); ok {
25-
return slice
26-
}
11+
if v, ok := cCtx.Value(name).([]uint64); ok {
12+
return v
2713
}
2814
return nil
2915
}

0 commit comments

Comments
 (0)