diff --git a/README.md b/README.md index 9983399..15082bb 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ type config struct { Port int `env:"PORT" required:"true"` Peers []string `env:"PEERS"` // you can use `delimiter` tag to specify separator, for example `delimiter:" "` ConnectionTimeout time.Duration `env:"TIMEOUT" default:"10s"` + LogLevel string `env:"LOG_LEVEL" enum:"debug info error" delimiter:" "` // the delimiter tag applies to enum as well } func main() { diff --git a/env.go b/env.go index 1c43c92..24ecd4e 100644 --- a/env.go +++ b/env.go @@ -107,6 +107,10 @@ func setField(t reflect.StructField, v reflect.Value, value string) (err error) return setSlice(t, v, value) } + if err = checkEnum(t, value); err != nil { + return fmt.Errorf("error setting %q: %v", t.Name, err) + } + if err = setBuiltInField(v, value); err != nil { return fmt.Errorf("error setting %q: %v", t.Name, err) } diff --git a/env_test.go b/env_test.go index ef7463c..2180b2c 100644 --- a/env_test.go +++ b/env_test.go @@ -388,6 +388,90 @@ func TestEnvWithDefaultWhenProvided(t *testing.T) { Equals(t, "goodbye", config.Prop) } +func TestEnvEnumUnmatched(t *testing.T) { + os.Setenv("PROP", "foo") + + config := struct { + Prop string `env:"PROP" enum:"1,2,3"` + }{} + + err := Set(&config) + ErrorNotNil(t, err) + Assert(t, strings.HasPrefix(err.Error(), `error setting "Prop": "foo" is not a member of [1,2,3]`)) +} + +func TestEnvEnumMatchedString(t *testing.T) { + os.Setenv("PROP", "foo") + + config := struct { + Prop string `env:"PROP" enum:"foo,bar,baz"` + }{} + + err := Set(&config) + ErrorNil(t, err) + Assert(t, config.Prop == "foo") +} + +func TestEnvEnumMatchedInteger(t *testing.T) { + os.Setenv("PROP", "1") + + config := struct { + Prop int `env:"PROP" enum:"1,2,3"` + }{} + + err := Set(&config) + ErrorNil(t, err) + Assert(t, config.Prop == 1) +} + +func TestEnvEnumMatchedStringSlice(t *testing.T) { + os.Setenv("PROP", "foo,bar,baz") + + config := struct { + Prop []string `env:"PROP" enum:"foo,bar,baz"` + }{} + + err := Set(&config) + ErrorNil(t, err) + Equals(t, []string{"foo", "bar", "baz"}, config.Prop) +} + +func TestEnvEnumUnmatchedStringSlice(t *testing.T) { + os.Setenv("PROP", "foo,bar,baz") + + config := struct { + Prop []string `env:"PROP" enum:"foo,baz"` + }{} + + err := Set(&config) + ErrorNotNil(t, err) + Equals(t, `error setting "Prop": "bar" is not a member of [foo,baz]`, err.Error()) +} + +func TestEnvEnumMatchedIntSlice(t *testing.T) { + os.Setenv("PROP", "1,3") + + config := struct { + Prop []int `env:"PROP" enum:"1,3,5"` + }{} + + err := Set(&config) + ErrorNil(t, err) + Equals(t, []int{1, 3}, config.Prop) +} + +func TestEnvEnumUnmatchedIntSlice(t *testing.T) { + os.Setenv("PROP", "1,2,3") + + config := struct { + Prop []int `env:"PROP" enum:"1,3,5"` + }{} + + err := Set(&config) + ErrorNotNil(t, err) + Equals(t, `error setting "Prop": "2" is not a member of [1,3,5]`, err.Error()) +} + func TestEnvWithDefaultWhenMissing(t *testing.T) { unsetEnvironment() diff --git a/set.go b/set.go index d6be59a..d24bbef 100644 --- a/set.go +++ b/set.go @@ -3,6 +3,7 @@ package env import ( "fmt" "reflect" + "slices" "strconv" "strings" "time" @@ -106,6 +107,12 @@ func setSlice(t reflect.StructField, v reflect.Value, value string) (err error) return } + for _, rawValue := range rawValues { + if err = checkEnum(t, rawValue); err != nil { + return fmt.Errorf("error setting %q: %v", t.Name, err) + } + } + sliceValue, err := makeSlice(v, len(rawValues)) if err != nil { return @@ -179,3 +186,18 @@ func getDelimiter(t reflect.StructField) string { } return "," } + +func checkEnum(t reflect.StructField, value string) (err error) { + rawChoices, ok := t.Tag.Lookup("enum") + if !ok { + return + } + + delimiter := getDelimiter(t) + choices := split(rawChoices, delimiter) + if !slices.Contains(choices, value) { + return fmt.Errorf(`"%s" is not a member of [%s]`, value, strings.Join(choices, ",")) + } + + return +}