Skip to content

Commit 6fc1d08

Browse files
committed
Add new net.IPNet flag type
1 parent 30bf08c commit 6fc1d08

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

ipnet.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package pflag
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"strings"
7+
)
8+
9+
// IPNet adapts net.IPNet for use as a flag.
10+
type IPNetValue net.IPNet
11+
12+
func (ipnet IPNetValue) String() string {
13+
n := net.IPNet(ipnet)
14+
return n.String()
15+
}
16+
17+
func (ipnet *IPNetValue) Set(value string) error {
18+
_, n, err := net.ParseCIDR(strings.TrimSpace(value))
19+
if err != nil {
20+
return err
21+
}
22+
*ipnet = IPNetValue(*n)
23+
return nil
24+
}
25+
26+
func (*IPNetValue) Type() string {
27+
return "ipNet"
28+
}
29+
30+
var _ = strings.TrimSpace
31+
32+
func newIPNetValue(val net.IPNet, p *net.IPNet) *IPNetValue {
33+
*p = val
34+
return (*IPNetValue)(p)
35+
}
36+
37+
func ipNetConv(sval string) (interface{}, error) {
38+
_, n, err := net.ParseCIDR(strings.TrimSpace(sval))
39+
if err == nil {
40+
return *n, nil
41+
}
42+
return nil, fmt.Errorf("invalid string being converted to IPNet: %s", sval)
43+
}
44+
45+
// GetIPNet return the net.IPNet value of a flag with the given name
46+
func (f *FlagSet) GetIPNet(name string) (net.IPNet, error) {
47+
val, err := f.getFlagType(name, "ipNet", ipNetConv)
48+
if err != nil {
49+
return net.IPNet{}, err
50+
}
51+
return val.(net.IPNet), nil
52+
}
53+
54+
// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string.
55+
// The argument p points to an net.IPNet variable in which to store the value of the flag.
56+
func (f *FlagSet) IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
57+
f.VarP(newIPNetValue(value, p), name, "", usage)
58+
}
59+
60+
// Like IPNetVar, but accepts a shorthand letter that can be used after a single dash.
61+
func (f *FlagSet) IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
62+
f.VarP(newIPNetValue(value, p), name, shorthand, usage)
63+
}
64+
65+
// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string.
66+
// The argument p points to an net.IPNet variable in which to store the value of the flag.
67+
func IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) {
68+
CommandLine.VarP(newIPNetValue(value, p), name, "", usage)
69+
}
70+
71+
// Like IPNetVar, but accepts a shorthand letter that can be used after a single dash.
72+
func IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) {
73+
CommandLine.VarP(newIPNetValue(value, p), name, shorthand, usage)
74+
}
75+
76+
// IPNet defines an net.IPNet flag with specified name, default value, and usage string.
77+
// The return value is the address of an net.IPNet variable that stores the value of the flag.
78+
func (f *FlagSet) IPNet(name string, value net.IPNet, usage string) *net.IPNet {
79+
p := new(net.IPNet)
80+
f.IPNetVarP(p, name, "", value, usage)
81+
return p
82+
}
83+
84+
// Like IPNet, but accepts a shorthand letter that can be used after a single dash.
85+
func (f *FlagSet) IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
86+
p := new(net.IPNet)
87+
f.IPNetVarP(p, name, shorthand, value, usage)
88+
return p
89+
}
90+
91+
// IPNet defines an net.IPNet flag with specified name, default value, and usage string.
92+
// The return value is the address of an net.IPNet variable that stores the value of the flag.
93+
func IPNet(name string, value net.IPNet, usage string) *net.IPNet {
94+
return CommandLine.IPNetP(name, "", value, usage)
95+
}
96+
97+
// Like IPNet, but accepts a shorthand letter that can be used after a single dash.
98+
func IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet {
99+
return CommandLine.IPNetP(name, shorthand, value, usage)
100+
}

ipnet_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package pflag
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"os"
7+
"testing"
8+
)
9+
10+
func setUpIPNet(ip *net.IPNet) *FlagSet {
11+
f := NewFlagSet("test", ContinueOnError)
12+
_, def, _ := net.ParseCIDR("0.0.0.0/0")
13+
f.IPNetVar(ip, "address", *def, "IP Address")
14+
return f
15+
}
16+
17+
func TestIPNet(t *testing.T) {
18+
testCases := []struct {
19+
input string
20+
success bool
21+
expected string
22+
}{
23+
{"0.0.0.0/0", true, "0.0.0.0/0"},
24+
{" 0.0.0.0/0 ", true, "0.0.0.0/0"},
25+
{"1.2.3.4/8", true, "1.0.0.0/8"},
26+
{"127.0.0.1/16", true, "127.0.0.0/16"},
27+
{"255.255.255.255/19", true, "255.255.224.0/19"},
28+
{"255.255.255.255/32", true, "255.255.255.255/32"},
29+
{"", false, ""},
30+
{"/0", false, ""},
31+
{"0", false, ""},
32+
{"0/0", false, ""},
33+
{"localhost/0", false, ""},
34+
{"0.0.0/4", false, ""},
35+
{"0.0.0./8", false, ""},
36+
{"0.0.0.0./12", false, ""},
37+
{"0.0.0.256/16", false, ""},
38+
{"0.0.0.0 /20", false, ""},
39+
{"0.0.0.0/ 24", false, ""},
40+
{"0 . 0 . 0 . 0 / 28", false, ""},
41+
{"0.0.0.0/33", false, ""},
42+
}
43+
44+
devnull, _ := os.Open(os.DevNull)
45+
os.Stderr = devnull
46+
for i := range testCases {
47+
var addr net.IPNet
48+
f := setUpIPNet(&addr)
49+
50+
tc := &testCases[i]
51+
52+
arg := fmt.Sprintf("--address=%s", tc.input)
53+
err := f.Parse([]string{arg})
54+
if err != nil && tc.success == true {
55+
t.Errorf("expected success, got %q", err)
56+
continue
57+
} else if err == nil && tc.success == false {
58+
t.Errorf("expected failure")
59+
continue
60+
} else if tc.success {
61+
ip, err := f.GetIPNet("address")
62+
if err != nil {
63+
t.Errorf("Got error trying to fetch the IP flag: %v", err)
64+
}
65+
if ip.String() != tc.expected {
66+
t.Errorf("expected %q, got %q", tc.expected, ip.String())
67+
}
68+
}
69+
}
70+
}

0 commit comments

Comments
 (0)