Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add the ability to overlay configuration #698

Closed
wants to merge 9 commits into from
43 changes: 39 additions & 4 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ package config
import (
"fmt"
"log"
"os"
"reflect"
"strings"
"time"

"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/spf13/viper"
)

Expand Down Expand Up @@ -51,6 +53,7 @@ type Server struct {
Port int `yaml:"port"`
}

// Config describes the complete configuration for a Retina process.
type Config struct {
APIServer Server `yaml:"apiServer"`
LogLevel string `yaml:"logLevel"`
Expand All @@ -68,14 +71,37 @@ type Config struct {
MonitorSockPath string `yaml:"monitorSockPath"`
}

func GetConfig(cfgFilename string) (*Config, error) {
if cfgFilename != "" {
viper.SetConfigFile(cfgFilename)
type FilteredConfig struct {
Filename string
AllowedFields []string
}

func mergeConfig(file FilteredConfig) error {
f, err := os.Open(file.Filename)
if err != nil {
return errors.Wrapf(err, "opening config file %q", file)
}
defer f.Close()

fy, err := NewFilteredYAML(f, file.AllowedFields)
if err != nil {
return errors.Wrap(err, "creating FilteredYAML")
}

err = viper.MergeConfig(fy)
if err != nil {
return errors.Wrap(err, "merging config with viper")
}
return nil
}

func GetConfig(primaryCfg string, overlays ...FilteredConfig) (*Config, error) {
if primaryCfg != "" {
viper.SetConfigFile(primaryCfg)
} else {
viper.SetConfigName("config")
viper.AddConfigPath("/retina/config")
}

viper.SetEnvPrefix("retina")
viper.AutomaticEnv()
// NOTE(mainred): RetinaEndpoint is currently the only supported solution to cache Pod, and before an alternative is implemented,
Expand All @@ -86,6 +112,15 @@ func GetConfig(cfgFilename string) (*Config, error) {
if err != nil {
return nil, fmt.Errorf("fatal error config file: %s", err)
}

// apply overlay configs
for _, file := range overlays {
err := mergeConfig(file) //nolint:govet // shadowing is fine here
if err != nil {
return nil, errors.Wrapf(err, "merging config for %q", file)
}
}

var config Config
decoderConfigOption := func(dc *mapstructure.DecoderConfig) {
dc.DecodeHook = mapstructure.ComposeDecodeHookFunc(
Expand Down
20 changes: 18 additions & 2 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func TestGetConfig(t *testing.T) {
c, err := GetConfig("./testwith/config.yaml")
c, err := GetConfig("./testdata/config.yaml")
if err != nil {
t.Fatalf("Expected no error, instead got %+v", err)
}
Expand All @@ -27,7 +27,23 @@ func TestGetConfig(t *testing.T) {
c.RemoteContext ||
c.EnableAnnotations ||
c.DataAggregationLevel != Low {
t.Fatalf("Expeted config should be same as ./testwith/config.yaml; instead got %+v", c)
t.Fatalf("Expeted config should be same as ./testdata/config.yaml; instead got %+v", c)
}
}

func TestGetConfigOverlay(t *testing.T) {
c, err := GetConfig("./testdata/config.yaml", FilteredConfig{
Filename: "./testdata/overlay.yaml",
AllowedFields: []string{
"logLevel",
},
})
if err != nil {
t.Fatal("err getting config: err:", err)
}

if c.LogLevel != "debug" {
t.Error("expected LogLevel to be overridden to debug, but found:", c.LogLevel)
}
}

Expand Down
62 changes: 62 additions & 0 deletions pkg/config/filtered_yaml.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package config

import (
"bytes"
"io"

"github.com/pkg/errors"
"gopkg.in/yaml.v2"
)

func NewFilteredYAML(source io.ReadCloser, allowedFields []string) (*FilteredYAML, error) {
f := &FilteredYAML{
YAML: source,
AllowedFields: allowedFields,
buf: &bytes.Buffer{},
}

if err := f.filter(); err != nil {
return nil, errors.Wrap(err, "filtering yaml")
}

return f, nil
}

// FilteredYAML is a YAML config that is restricted to a specified allowlist of
// fields. Any additional fields found will be removed, such that the resulting
// configuration is the subset of fields found in the allowlist.
type FilteredYAML struct {
YAML io.ReadCloser // the input YAML
AllowedFields []string // the set of allowed fields in the resulting YAML
buf *bytes.Buffer
}

func (f *FilteredYAML) filter() error {
defer f.YAML.Close()
f.buf = bytes.NewBufferString("")

decoded := make(map[string]any)
err := yaml.NewDecoder(f.YAML).Decode(&decoded)
if err != nil && !errors.Is(err, io.EOF) {
return errors.Wrap(err, "reading input YAML")
}

filtered := make(map[string]any, len(decoded))
for _, field := range f.AllowedFields {
if val, ok := decoded[field]; ok {
filtered[field] = val
}
}

err = yaml.NewEncoder(f.buf).Encode(filtered)
if err != nil {
return errors.Wrap(err, "remarshaling filtered yaml")
}
return nil
}

// Read extracts the subset of YAML matching AllowedFields and writes it to the
// supplied buffer.
func (f *FilteredYAML) Read(out []byte) (int, error) {
return f.buf.Read(out) //nolint:wrapcheck // there's no value in wrapping this
}
60 changes: 60 additions & 0 deletions pkg/config/filtered_yaml_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package config_test

import (
"io"
"strings"
"testing"

"github.com/google/go-cmp/cmp"

"github.com/microsoft/retina/pkg/config"
)

func TestFilteredYAML(t *testing.T) {
tests := []struct {
name string
in string
allowed []string
exp string
}{
{
"empty",
"",
[]string{},
"{}\n",
},
{
"one field",
"foo: bar\n",
[]string{"foo"},
"foo: bar\n",
},
{
"two fields",
"foo: bar\nbaz: quux\n",
[]string{"foo"},
"foo: bar\n",
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

fy, err := config.NewFilteredYAML(io.NopCloser(strings.NewReader(test.in)), test.allowed)
if err != nil {
t.Fatal("unexpected error creating filtered yaml: err:", err)
}

got, err := io.ReadAll(fy)
if err != nil {
t.Fatal("unexpected error: err:", err)
}

if !cmp.Equal(test.exp, string(got)) {
t.Fatal("yaml differs from expected: diff:", cmp.Diff(test.exp, string(got)))
}
})
}
}
File renamed without changes.
2 changes: 2 additions & 0 deletions pkg/config/testdata/overlay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
---
logLevel: debug
Loading