Skip to content

Commit

Permalink
added WithGetenvFunc option so that from tests we do not have to rely…
Browse files Browse the repository at this point in the history
… on global state, defaults to os.Getenv still
  • Loading branch information
tempcke committed Feb 17, 2024
1 parent 21ddd05 commit e2a2f2f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
14 changes: 14 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type ParseContext struct {
configOpenFunc func(string) (iofs.File, error)
configAllowMissingFile bool
configIgnoreUndefinedFlags bool

getenvFunc func(string) string // alternate implementation of os.Getenv
}

// ConfigFileParseFunc is a function that consumes the provided reader as a config
Expand Down Expand Up @@ -137,3 +139,15 @@ func WithFilesystem(fs iofs.FS) Option {
pc.configOpenFunc = fs.Open
}
}

// WithGetenvFunc is like [WithEnvVars], but with a controlled environment lookup
// The provided function will be used rather than os.Getenv when looking up env vars
// This allows tests to be run using t.Parallel with a controlled environment state
//
// By default os.Getenv is used
func WithGetenvFunc(f func(string) string) Option {
return func(pc *ParseContext) {
pc.envVarEnabled = true
pc.getenvFunc = f
}
}
11 changes: 10 additions & 1 deletion parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func parse(fs Flags, args []string, options ...Option) error {
key := getEnvVarKey(name, pc.envVarPrefix)

// Look up the value from the environment.
val := os.Getenv(key)
val := getenv(pc.getenvFunc, key)
if val == "" {
continue
}
Expand Down Expand Up @@ -202,6 +202,15 @@ func parse(fs Flags, args []string, options ...Option) error {
return nil
}

// getenv returns the value of the environment variable with the given key.
// If fn is nil, os.Getenv is used. This is mostly used to inject non-global state from tests.
func getenv(fn func(string) string, key string) string {
if fn != nil {
return fn(key)
}
return os.Getenv(key)
}

//
//
//
Expand Down
41 changes: 41 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"flag"
"os"
"path/filepath"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -428,3 +429,43 @@ func TestParse_stdfs(t *testing.T) {
t.Errorf("foo: want %q, have %q", want, have)
}
}

func TestParse_Getenv(t *testing.T) {
t.Parallel()
// purpose of this test is to be able to provide a custom Getenv func
// this allows for cleaner tests which can control the env without mutating global state
var (
expectAddr = "foo"
expectCompress = true
expectTransform = true
expectLogLevel = "error"
envMap = map[string]string{
"X_ADDR": expectAddr,
"X_COMPRESS": strconv.FormatBool(expectCompress),
"X_TRANSFORM": strconv.FormatBool(expectTransform),
"X_LOG": expectLogLevel,
}
prefixOption = ff.WithEnvVarPrefix("X")
envOption = ff.WithGetenvFunc(func(key string) string { return envMap[key] })
)
fs := ff.NewFlagSet(t.Name())
var (
addr = fs.String('a', "addr", "", "remote address (repeatable)")
compress = fs.Bool('c', "compress", "enable compression")
transform = fs.Bool('t', "transform", "enable transformation")
loglevel = fs.StringEnum('l', "log", "log level: debug, info, error", "info", "debug", "error")
)
if err := ff.Parse(fs, nil, prefixOption, envOption); err != nil {
t.Fatal(err)
}
assertEqual(t, expectAddr, *addr)
assertEqual(t, expectCompress, *compress)
assertEqual(t, expectTransform, *transform)
assertEqual(t, expectLogLevel, *loglevel)
}
func assertEqual(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected != actual {
t.Errorf("Values not equal \nWant: %v \t%T\nGot: %v \t%T\n", expected, expected, actual, actual)
}
}

0 comments on commit e2a2f2f

Please sign in to comment.