diff --git a/options.go b/options.go index ebd3405..e79afaa 100644 --- a/options.go +++ b/options.go @@ -20,6 +20,8 @@ type ParseContext struct { configOpenFunc func(string) (iofs.File, error) configAllowMissingFile bool configIgnoreUndefinedFlags bool + + getenvFunc func(string) string // alternate implementation of os.Getenv } // ConfigFileParseFunc is a function that consumes the provided reader as a config @@ -137,3 +139,15 @@ func WithFilesystem(fs iofs.FS) Option { pc.configOpenFunc = fs.Open } } + +// WithGetenvFunc is like [WithEnvVars], but with a controlled environment lookup +// The provided function will be used rather than os.Getenv when looking up env vars +// This allows tests to be run using t.Parallel with a controlled environment state +// +// By default os.Getenv is used +func WithGetenvFunc(f func(string) string) Option { + return func(pc *ParseContext) { + pc.envVarEnabled = true + pc.getenvFunc = f + } +} diff --git a/parse.go b/parse.go index bdfb2c7..728e43e 100644 --- a/parse.go +++ b/parse.go @@ -91,7 +91,7 @@ func parse(fs Flags, args []string, options ...Option) error { key := getEnvVarKey(name, pc.envVarPrefix) // Look up the value from the environment. - val := os.Getenv(key) + val := getenv(pc.getenvFunc, key) if val == "" { continue } @@ -202,6 +202,15 @@ func parse(fs Flags, args []string, options ...Option) error { return nil } +// getenv returns the value of the environment variable with the given key. +// If fn is nil, os.Getenv is used. This is mostly used to inject non-global state from tests. +func getenv(fn func(string) string, key string) string { + if fn != nil { + return fn(key) + } + return os.Getenv(key) +} + // // // diff --git a/parse_test.go b/parse_test.go index 62f9b3e..ed21ba9 100644 --- a/parse_test.go +++ b/parse_test.go @@ -5,6 +5,7 @@ import ( "flag" "os" "path/filepath" + "strconv" "testing" "time" @@ -428,3 +429,43 @@ func TestParse_stdfs(t *testing.T) { t.Errorf("foo: want %q, have %q", want, have) } } + +func TestParse_Getenv(t *testing.T) { + t.Parallel() + // purpose of this test is to be able to provide a custom Getenv func + // this allows for cleaner tests which can control the env without mutating global state + var ( + expectAddr = "foo" + expectCompress = true + expectTransform = true + expectLogLevel = "error" + envMap = map[string]string{ + "X_ADDR": expectAddr, + "X_COMPRESS": strconv.FormatBool(expectCompress), + "X_TRANSFORM": strconv.FormatBool(expectTransform), + "X_LOG": expectLogLevel, + } + prefixOption = ff.WithEnvVarPrefix("X") + envOption = ff.WithGetenvFunc(func(key string) string { return envMap[key] }) + ) + fs := ff.NewFlagSet(t.Name()) + var ( + addr = fs.String('a', "addr", "", "remote address (repeatable)") + compress = fs.Bool('c', "compress", "enable compression") + transform = fs.Bool('t', "transform", "enable transformation") + loglevel = fs.StringEnum('l', "log", "log level: debug, info, error", "info", "debug", "error") + ) + if err := ff.Parse(fs, nil, prefixOption, envOption); err != nil { + t.Fatal(err) + } + assertEqual(t, expectAddr, *addr) + assertEqual(t, expectCompress, *compress) + assertEqual(t, expectTransform, *transform) + assertEqual(t, expectLogLevel, *loglevel) +} +func assertEqual(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected != actual { + t.Errorf("Values not equal \nWant: %v \t%T\nGot: %v \t%T\n", expected, expected, actual, actual) + } +}