diff --git a/api/internal/conf/conf.go b/api/internal/conf/conf.go index 3879eafe80..ab859ba39d 100644 --- a/api/internal/conf/conf.go +++ b/api/internal/conf/conf.go @@ -22,10 +22,12 @@ import ( "io/ioutil" "os" "path/filepath" + "reflect" "runtime" "strings" "github.com/gorilla/sessions" + "github.com/mitchellh/mapstructure" "github.com/spf13/viper" "github.com/tidwall/gjson" "golang.org/x/oauth2" @@ -184,6 +186,16 @@ func InitConf() { initSchema() } +func unmarshalConfig() (*Config, error) { + config := Config{} + err := viper.Unmarshal(&config, viper.DecodeHook(substituteEnvironmentVariables())) + if err != nil { + return nil, err + } + + return &config, nil +} + func setupConfig() { // setup config file path if ConfigFile == "" { @@ -204,8 +216,7 @@ func setupConfig() { } // unmarshal config - config := Config{} - err := viper.Unmarshal(&config) + config, err := unmarshalConfig() if err != nil { panic(fmt.Sprintf("fail to unmarshal configuration: %s, err: %s", ConfigFile, err.Error())) } @@ -432,3 +443,18 @@ func initSecurity(conf Security) { ContentSecurityPolicy: DefaultCSP, } } + +// substituteEnvironmentVariables substitutes environment variables in the configuration with real values +func substituteEnvironmentVariables() mapstructure.DecodeHookFuncKind { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}, + ) (interface{}, error) { + if f != reflect.String { + return data, nil + } + + return os.ExpandEnv(data.(string)), nil + } +} diff --git a/api/internal/conf/conf_test.go b/api/internal/conf/conf_test.go index 1a1a1c3ef3..2f05e90f27 100644 --- a/api/internal/conf/conf_test.go +++ b/api/internal/conf/conf_test.go @@ -1,9 +1,12 @@ package conf import ( + "bytes" "encoding/json" + "os" "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" ) @@ -61,3 +64,64 @@ func Test_mergeSchema(t *testing.T) { }) } } + +func Test_unmarshalConfig(t *testing.T) { + tests := []struct { + name string + init func() + config []byte + assert func(*Config, *testing.T) + }{ + { + name: "should correctly parse config without environment variables", + init: func() {}, + config: []byte("conf:\n listen:\n port: \"9000\""), + assert: func(config *Config, t *testing.T) { + assert.Equal(t, 9000, config.Conf.Listen.Port) + }, + }, + { + name: "should correctly substitute int port from environment variables", + init: func() { + os.Setenv("PORT", "8080") + }, + config: []byte("conf:\n listen:\n port: \"$PORT\""), + assert: func(config *Config, t *testing.T) { + assert.Equal(t, 8080, config.Conf.Listen.Port) + }, + }, + { + name: "should correctly substitute string etcd endpoint from environment variables", + init: func() { + os.Setenv("ETCD_ENDPOINT", "127.0.0.1:2379") + }, + config: []byte("conf:\n etcd:\n endpoints:\n - $ETCD_ENDPOINT"), + assert: func(config *Config, t *testing.T) { + assert.Equal(t, "127.0.0.1:2379", config.Conf.Etcd.Endpoints[0]) + }, + }, + } + + for _, tt := range tests { + + t.Run(tt.name, func(t *testing.T) { + viper.SetConfigType("yaml") + viper.AutomaticEnv() + + tt.init() + + err := viper.ReadConfig(bytes.NewBuffer(tt.config)) + if err != nil { + t.Errorf("unable to read config: %v", err) + return + } + config, err := unmarshalConfig() + if err != nil { + t.Errorf("unable to unmarshall config: %v", err) + return + } + + tt.assert(config, t) + }) + } +}