Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions core/pkg/store/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,23 @@ type Selector struct {
indexMap map[string]string
}

var validSelectorKeys = map[string]struct{}{
flagSetIdIndex: {},
sourceIndex: {},
}

// NewSelector creates a new Selector from a selector expression string.
// #1708 Until we decide on the Selector syntax, only a single key=value pair is supported
// For example, to select flags from source "./mySource" or flagSetId "1234", use the expressions:
// "source=./mySource" or "flagSetId=1234"
func NewSelector(selectorExpression string) Selector {
return Selector{
indexMap: expressionToMap(selectorExpression),
func NewSelector(selectorExpression string) (Selector, error) {
m := expressionToMap(selectorExpression)
for key := range m {
if _, ok := validSelectorKeys[key]; !ok {
return Selector{}, fmt.Errorf("invalid selector key %q, valid keys: %q, %q", key, flagSetIdIndex, sourceIndex)
}
}
return Selector{indexMap: m}, nil
}

func expressionToMap(sExp string) map[string]string {
Expand Down Expand Up @@ -70,13 +79,17 @@ func expressionToMap(sExp string) map[string]string {
return selectorMap
}

// WithIndex creates a new Selector from the current Selector and adds the given key-value-pair
func (s Selector) WithIndex(key string, value string) Selector {
func (s Selector) WithSource(source string) Selector { return s.withIndex(sourceIndex, source) }
func (s Selector) WithFlagSetId(id string) Selector { return s.withIndex(flagSetIdIndex, id) }
func (s Selector) withKey(key string) Selector { return s.withIndex(keyIndex, key) }

func (s Selector) withIndex(key, value string) Selector {
m := maps.Clone(s.indexMap)
m[key] = value
return Selector{
indexMap: m,
if m == nil {
m = make(map[string]string, 1)
}
m[key] = value
return Selector{indexMap: m}
}

func (s *Selector) IsEmpty() bool {
Expand Down
37 changes: 26 additions & 11 deletions core/pkg/store/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,27 @@ func TestSelector_IsEmpty(t *testing.T) {
}
}

func TestSelector_WithIndex(t *testing.T) {
oldS := Selector{indexMap: map[string]string{"source": "abc"}}
newS := oldS.WithIndex("flagSetId", "1234")
func TestSelector_WithSourceAndFlagSetId(t *testing.T) {
s := Selector{}.WithSource("abc")
if s.indexMap[sourceIndex] != "abc" {
t.Errorf("WithSource did not set source")
}

if newS.indexMap["source"] != "abc" {
t.Errorf("WithIndex did not preserve existing keys")
s2 := s.WithFlagSetId("1234")
if s2.indexMap[sourceIndex] != "abc" {
t.Errorf("WithFlagSetId did not preserve source")
}
if newS.indexMap["flagSetId"] != "1234" {
t.Errorf("WithIndex did not add new key")
if s2.indexMap[flagSetIdIndex] != "1234" {
t.Errorf("WithFlagSetId did not set flagSetId")
}

// Ensure original is unchanged
if _, ok := oldS.indexMap["flagSetId"]; ok {
t.Errorf("WithIndex mutated original selector")
if _, ok := s.indexMap[flagSetIdIndex]; ok {
t.Errorf("WithFlagSetId mutated original selector")
}
}


func TestSelector_ToQuery(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -176,6 +181,7 @@ func TestNewSelector(t *testing.T) {
name string
input string
wantMap map[string]string
wantErr bool
}{
// #1708 Until we decide on the Selector syntax, only a single key=value pair is supported
/*
Expand Down Expand Up @@ -205,12 +211,21 @@ func TestNewSelector(t *testing.T) {
input: "",
wantMap: map[string]string{},
},
{
name: "invalid key",
input: "flagSetIds=abc",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewSelector(tt.input)
if !reflect.DeepEqual(s.indexMap, tt.wantMap) {
s, err := NewSelector(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("NewSelector(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if !tt.wantErr && !reflect.DeepEqual(s.indexMap, tt.wantMap) {
t.Errorf("NewSelector(%q) indexMap = %v, want %v", tt.input, s.indexMap, tt.wantMap)
}
})
Expand Down
6 changes: 3 additions & 3 deletions core/pkg/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (s *Store) Get(_ context.Context, key string, selector *Selector) (model.Fl

// if present, use the selector to query the flags
if !selector.IsEmpty() {
selector := selector.WithIndex("key", key)
selector := selector.withKey(key)
indexId, constraints := selector.ToQuery()
s.logger.Debug(fmt.Sprintf("getting flag with query: %s, %v", indexId, constraints))
raw, err := txn.First(flagsTable, indexId, constraints...)
Expand Down Expand Up @@ -274,7 +274,7 @@ func (s *Store) Update(
seenFlagSetIds[id.flagSetId] = struct{}{}
}
for fsi := range seenFlagSetIds {
sel := NewSelector(flagSetIdIndex+"="+fsi).WithIndex(sourceIndex, source)
sel := Selector{}.WithFlagSetId(fsi).WithSource(source)
indexId, constraints := sel.ToQuery()
it, err := txn.Get(flagsTable, indexId, constraints...)
if err != nil {
Expand All @@ -284,7 +284,7 @@ func (s *Store) Update(
oldFlags = append(oldFlags, s.collect(it)...)
}
} else {
sel := NewSelector(sourceIndex + "=" + source)
sel := Selector{}.WithSource(source)
indexId, constraints := sel.ToQuery()
it, err := txn.Get(flagsTable, indexId, constraints...)
if err != nil {
Expand Down
24 changes: 12 additions & 12 deletions core/pkg/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ func TestGet(t *testing.T) {

sources := []string{sourceA.Name, sourceB.Name, sourceC.Name}

sourceASelector := NewSelector("source=" + sourceA.Name)
flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC)
sourceASelector := Selector{}.WithSource(sourceA.Name)
flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC)

t.Parallel()
tests := []struct {
Expand Down Expand Up @@ -377,10 +377,10 @@ func TestGetAllNoWatcher(t *testing.T) {

sources := []string{sourceA.Name, sourceB.Name, sourceC.Name}

sourceASelector := NewSelector("source=" + sourceA.Name)
flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC)
sourceASelector := Selector{}.WithSource(sourceA.Name)
flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC)
// #1708 Until we decide on the Selector syntax, only a single key=value pair is supported
//flagSetIdAndCSelector := NewSelector("flagSetId=" + flagSetIdC + ",source=" + sourceC.Name)
//flagSetIdAndCSelector := Selector{}.WithFlagSetId(flagSetIdC).withIndex(sourceIndex, sourceC.Name)

t.Parallel()
tests := []struct {
Expand Down Expand Up @@ -506,10 +506,10 @@ func TestWatch(t *testing.T) {
pauseTime := 100 * time.Millisecond // time for updates to settle
timeout := 1000 * time.Millisecond // time to make sure we get enough updates, and no extras

sourceASelector := NewSelector("source=" + sourceA)
flagSetIdCSelector := NewSelector("flagSetId=" + myFlagSetId)
emptySelector := NewSelector("")
sourceCSelector := NewSelector("source=" + sourceC)
sourceASelector := Selector{}.WithSource(sourceA)
flagSetIdCSelector := Selector{}.WithFlagSetId(myFlagSetId)
emptySelector := Selector{}
sourceCSelector := Selector{}.WithSource(sourceC)

tests := []struct {
name string
Expand Down Expand Up @@ -786,15 +786,15 @@ func TestQueryMetadata(t *testing.T) {
// #1708 Until we decide on the Selector syntax, only a single key=value pair is supported
// these tests should then also cover more complex selectors

selector := NewSelector("flagSetId=" + nonExistingFlagSetId)
selector := Selector{}.WithFlagSetId(nonExistingFlagSetId)
_, metadata, _ := store.GetAll(context.Background(), &selector)
assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected")

selector = NewSelector("flagSetId=" + nonExistingFlagSetId)
selector = Selector{}.WithFlagSetId(nonExistingFlagSetId)
_, metadata, _ = store.Get(context.Background(), "key", &selector)
assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected")

selector = NewSelector("source=" + otherSource)
selector = Selector{}.WithSource(otherSource)
_, metadata, _ = store.Get(context.Background(), "key", &selector)
assert.Equal(t, metadata, model.Metadata{"source": otherSource}, "metadata did not match expected")
}
13 changes: 13 additions & 0 deletions flagd/pkg/service/flag-evaluation/context_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,21 @@ package service

import (
"net/http"

"connectrpc.com/connect"
flagdService "github.com/open-feature/flagd/flagd/pkg/service"
"github.com/open-feature/flagd/core/pkg/store"
)

func selectorFromHeader(header http.Header) (store.Selector, error) {
expr := header.Get(flagdService.FLAGD_SELECTOR_HEADER)
s, err := store.NewSelector(expr)
if err != nil {
return store.Selector{}, connect.NewError(connect.CodeInvalidArgument, err)
}
return s, nil
}

// MergeContextsAndHeaders merges evaluation contexts with static context values and header-based context.
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
// Header names are matched case-insensitively according to HTTP specification.
Expand Down
53 changes: 33 additions & 20 deletions flagd/pkg/service/flag-evaluation/flag_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/open-feature/flagd/core/pkg/service"
"github.com/open-feature/flagd/core/pkg/store"
"github.com/open-feature/flagd/core/pkg/telemetry"
flagdService "github.com/open-feature/flagd/flagd/pkg/service"
"github.com/rs/xid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -75,8 +74,10 @@ func (s *OldFlagEvaluationService) ResolveAll(
Flags: make(map[string]*schemaV1.AnyFlag),
}

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
Expand Down Expand Up @@ -142,8 +143,10 @@ func (s *OldFlagEvaluationService) EventStream(
s.logger.Debug(fmt.Sprintf("starting event stream for request"))

requestNotificationChan := make(chan service.Notification, 1)
selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return err
}
s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan)
defer s.eventingConfiguration.Unsubscribe(req)

Expand Down Expand Up @@ -185,11 +188,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
ctx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})
selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

err := resolve[bool](
err = resolve[bool](
ctx,
s.logger,
s.eval.ResolveBooleanValue,
Expand Down Expand Up @@ -217,12 +222,14 @@ func (s *OldFlagEvaluationService) ResolveString(
ctx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

res := connect.NewResponse(&schemaV1.ResolveStringResponse{})
err := resolve[string](
err = resolve[string](
ctx,
s.logger,
s.eval.ResolveStringValue,
Expand Down Expand Up @@ -250,12 +257,14 @@ func (s *OldFlagEvaluationService) ResolveInt(
ctx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

res := connect.NewResponse(&schemaV1.ResolveIntResponse{})
err := resolve[int64](
err = resolve[int64](
ctx,
s.logger,
s.eval.ResolveIntValue,
Expand Down Expand Up @@ -283,12 +292,14 @@ func (s *OldFlagEvaluationService) ResolveFloat(
ctx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

res := connect.NewResponse(&schemaV1.ResolveFloatResponse{})
err := resolve[float64](
err = resolve[float64](
ctx,
s.logger,
s.eval.ResolveFloatValue,
Expand Down Expand Up @@ -316,12 +327,14 @@ func (s *OldFlagEvaluationService) ResolveObject(
ctx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selector, err := selectorFromHeader(req.Header())
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)

res := connect.NewResponse(&schemaV1.ResolveObjectResponse{})
err := resolve[map[string]any](
err = resolve[map[string]any](
ctx,
s.logger,
s.eval.ResolveObjectValue,
Expand Down
Loading
Loading