diff --git a/flagext/urlescaped.go b/flagext/urlescaped.go new file mode 100644 index 000000000..0bac0eb82 --- /dev/null +++ b/flagext/urlescaped.go @@ -0,0 +1,63 @@ +package flagext + +import "net/url" + +// URLEscaped is a url.URL that can be used as a flag. +// URL value it contains will always be URL escaped and safe. +type URLEscaped struct { + *url.URL +} + +// String implements flag.Value +func (v URLEscaped) String() string { + if v.URL == nil { + return "" + } + return v.URL.String() +} + +// Set implements flag.Value +// Set make sure given URL string is escaped. +func (v *URLEscaped) Set(s string) error { + s = url.QueryEscape(s) + + u, err := url.Parse(s) + if err != nil { + return err + } + v.URL = u + return nil +} + +// UnmarshalYAML implements yaml.Unmarshaler. +func (v *URLEscaped) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + // An empty string means no URL has been configured. + if s == "" { + v.URL = nil + return nil + } + + return v.Set(s) +} + +// Marshalyaml Implements yaml.Marshaler. +func (v URLEscaped) MarshalYAML() (interface{}, error) { + if v.URL == nil { + return "", nil + } + + // Mask out passwords when marshalling URLs back to YAML. + u := *v.URL + if u.User != nil { + if _, set := u.User.Password(); set { + u.User = url.UserPassword(u.User.Username(), "********") + } + } + + return u.String(), nil +} diff --git a/flagext/urlescaped_test.go b/flagext/urlescaped_test.go new file mode 100644 index 000000000..12c0ebee5 --- /dev/null +++ b/flagext/urlescaped_test.go @@ -0,0 +1,38 @@ +package flagext + +import ( + "flag" + "fmt" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +const ( + testS3URL = "s3://ASDFGHJIQWETTYUI:Jkasduahdkjh213kj1h31+lkjaflkjzvKASDOasofhjafaKFAF/GoQd@region/bucket_name" +) + +func TestURLEscaped(t *testing.T) { + expected, err := url.Parse(url.QueryEscape(testS3URL)) + require.NoError(t, err) + + // flag + var v URLEscaped + flags := flag.NewFlagSet("test", flag.ExitOnError) + flags.Var(&v, "v", "some secret credentials") + err = flags.Parse([]string{"-v", testS3URL}) + + assert.Equal(t, v.String(), expected.String()) + + // yaml + yv := struct { + S3 URLEscaped `yaml:"s3"` + }{} + err = yaml.Unmarshal([]byte(fmt.Sprintf("s3: %s", testS3URL)), &yv) + assert.NoError(t, err) + assert.Equal(t, yv.S3.String(), expected.String()) + +}