diff --git a/core/pkg/store/query.go b/core/pkg/store/query.go index 478440a8b..e89c52250 100644 --- a/core/pkg/store/query.go +++ b/core/pkg/store/query.go @@ -33,14 +33,23 @@ type Selector struct { indexMap map[string]string } +var validSelectorKeys = map[string]struct{}{ + flagSetIdIndex: {}, + sourceIndex: {}, +} + // NewSelector creates a new Selector from a selector expression string. // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported // For example, to select flags from source "./mySource" or flagSetId "1234", use the expressions: // "source=./mySource" or "flagSetId=1234" -func NewSelector(selectorExpression string) Selector { - return Selector{ - indexMap: expressionToMap(selectorExpression), +func NewSelector(selectorExpression string) (Selector, error) { + m := expressionToMap(selectorExpression) + for key := range m { + if _, ok := validSelectorKeys[key]; !ok { + return Selector{}, fmt.Errorf("invalid selector key %q, valid keys: %q, %q", key, flagSetIdIndex, sourceIndex) + } } + return Selector{indexMap: m}, nil } func expressionToMap(sExp string) map[string]string { @@ -70,13 +79,17 @@ func expressionToMap(sExp string) map[string]string { return selectorMap } -// WithIndex creates a new Selector from the current Selector and adds the given key-value-pair -func (s Selector) WithIndex(key string, value string) Selector { +func (s Selector) WithSource(source string) Selector { return s.withIndex(sourceIndex, source) } +func (s Selector) WithFlagSetId(id string) Selector { return s.withIndex(flagSetIdIndex, id) } +func (s Selector) withKey(key string) Selector { return s.withIndex(keyIndex, key) } + +func (s Selector) withIndex(key, value string) Selector { m := maps.Clone(s.indexMap) - m[key] = value - return Selector{ - indexMap: m, + if m == nil { + m = make(map[string]string, 1) } + m[key] = value + return Selector{indexMap: m} } func (s *Selector) IsEmpty() bool { diff --git a/core/pkg/store/query_test.go b/core/pkg/store/query_test.go index 042b3248b..76d3a7145 100644 --- a/core/pkg/store/query_test.go +++ b/core/pkg/store/query_test.go @@ -50,22 +50,27 @@ func TestSelector_IsEmpty(t *testing.T) { } } -func TestSelector_WithIndex(t *testing.T) { - oldS := Selector{indexMap: map[string]string{"source": "abc"}} - newS := oldS.WithIndex("flagSetId", "1234") +func TestSelector_WithSourceAndFlagSetId(t *testing.T) { + s := Selector{}.WithSource("abc") + if s.indexMap[sourceIndex] != "abc" { + t.Errorf("WithSource did not set source") + } - if newS.indexMap["source"] != "abc" { - t.Errorf("WithIndex did not preserve existing keys") + s2 := s.WithFlagSetId("1234") + if s2.indexMap[sourceIndex] != "abc" { + t.Errorf("WithFlagSetId did not preserve source") } - if newS.indexMap["flagSetId"] != "1234" { - t.Errorf("WithIndex did not add new key") + if s2.indexMap[flagSetIdIndex] != "1234" { + t.Errorf("WithFlagSetId did not set flagSetId") } + // Ensure original is unchanged - if _, ok := oldS.indexMap["flagSetId"]; ok { - t.Errorf("WithIndex mutated original selector") + if _, ok := s.indexMap[flagSetIdIndex]; ok { + t.Errorf("WithFlagSetId mutated original selector") } } + func TestSelector_ToQuery(t *testing.T) { tests := []struct { name string @@ -176,6 +181,7 @@ func TestNewSelector(t *testing.T) { name string input string wantMap map[string]string + wantErr bool }{ // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported /* @@ -205,12 +211,21 @@ func TestNewSelector(t *testing.T) { input: "", wantMap: map[string]string{}, }, + { + name: "invalid key", + input: "flagSetIds=abc", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewSelector(tt.input) - if !reflect.DeepEqual(s.indexMap, tt.wantMap) { + s, err := NewSelector(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("NewSelector(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(s.indexMap, tt.wantMap) { t.Errorf("NewSelector(%q) indexMap = %v, want %v", tt.input, s.indexMap, tt.wantMap) } }) diff --git a/core/pkg/store/store.go b/core/pkg/store/store.go index 48b80b107..5cc53c10f 100644 --- a/core/pkg/store/store.go +++ b/core/pkg/store/store.go @@ -153,7 +153,7 @@ func (s *Store) Get(_ context.Context, key string, selector *Selector) (model.Fl // if present, use the selector to query the flags if !selector.IsEmpty() { - selector := selector.WithIndex("key", key) + selector := selector.withKey(key) indexId, constraints := selector.ToQuery() s.logger.Debug(fmt.Sprintf("getting flag with query: %s, %v", indexId, constraints)) raw, err := txn.First(flagsTable, indexId, constraints...) @@ -274,7 +274,7 @@ func (s *Store) Update( seenFlagSetIds[id.flagSetId] = struct{}{} } for fsi := range seenFlagSetIds { - sel := NewSelector(flagSetIdIndex+"="+fsi).WithIndex(sourceIndex, source) + sel := Selector{}.WithFlagSetId(fsi).WithSource(source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { @@ -284,7 +284,7 @@ func (s *Store) Update( oldFlags = append(oldFlags, s.collect(it)...) } } else { - sel := NewSelector(sourceIndex + "=" + source) + sel := Selector{}.WithSource(source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { diff --git a/core/pkg/store/store_test.go b/core/pkg/store/store_test.go index a729c4e49..693c021f6 100644 --- a/core/pkg/store/store_test.go +++ b/core/pkg/store/store_test.go @@ -238,8 +238,8 @@ func TestGet(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSelector("source=" + sourceA.Name) - flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC) + sourceASelector := Selector{}.WithSource(sourceA.Name) + flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC) t.Parallel() tests := []struct { @@ -377,10 +377,10 @@ func TestGetAllNoWatcher(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSelector("source=" + sourceA.Name) - flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC) + sourceASelector := Selector{}.WithSource(sourceA.Name) + flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC) // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported - //flagSetIdAndCSelector := NewSelector("flagSetId=" + flagSetIdC + ",source=" + sourceC.Name) + //flagSetIdAndCSelector := Selector{}.WithFlagSetId(flagSetIdC).withIndex(sourceIndex, sourceC.Name) t.Parallel() tests := []struct { @@ -506,10 +506,10 @@ func TestWatch(t *testing.T) { pauseTime := 100 * time.Millisecond // time for updates to settle timeout := 1000 * time.Millisecond // time to make sure we get enough updates, and no extras - sourceASelector := NewSelector("source=" + sourceA) - flagSetIdCSelector := NewSelector("flagSetId=" + myFlagSetId) - emptySelector := NewSelector("") - sourceCSelector := NewSelector("source=" + sourceC) + sourceASelector := Selector{}.WithSource(sourceA) + flagSetIdCSelector := Selector{}.WithFlagSetId(myFlagSetId) + emptySelector := Selector{} + sourceCSelector := Selector{}.WithSource(sourceC) tests := []struct { name string @@ -786,15 +786,15 @@ func TestQueryMetadata(t *testing.T) { // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported // these tests should then also cover more complex selectors - selector := NewSelector("flagSetId=" + nonExistingFlagSetId) + selector := Selector{}.WithFlagSetId(nonExistingFlagSetId) _, metadata, _ := store.GetAll(context.Background(), &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewSelector("flagSetId=" + nonExistingFlagSetId) + selector = Selector{}.WithFlagSetId(nonExistingFlagSetId) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewSelector("source=" + otherSource) + selector = Selector{}.WithSource(otherSource) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"source": otherSource}, "metadata did not match expected") } diff --git a/flagd/pkg/service/flag-evaluation/context_utils.go b/flagd/pkg/service/flag-evaluation/context_utils.go index 72d7c00ad..710fab55d 100644 --- a/flagd/pkg/service/flag-evaluation/context_utils.go +++ b/flagd/pkg/service/flag-evaluation/context_utils.go @@ -2,8 +2,21 @@ package service import ( "net/http" + + "connectrpc.com/connect" + flagdService "github.com/open-feature/flagd/flagd/pkg/service" + "github.com/open-feature/flagd/core/pkg/store" ) +func selectorFromHeader(header http.Header) (store.Selector, error) { + expr := header.Get(flagdService.FLAGD_SELECTOR_HEADER) + s, err := store.NewSelector(expr) + if err != nil { + return store.Selector{}, connect.NewError(connect.CodeInvalidArgument, err) + } + return s, nil +} + // MergeContextsAndHeaders merges evaluation contexts with static context values and header-based context. // highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority // Header names are matched case-insensitively according to HTTP specification. diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index ff376ebfb..3826c2f34 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -14,7 +14,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -75,8 +74,10 @@ func (s *OldFlagEvaluationService) ResolveAll( Flags: make(map[string]*schemaV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string))) @@ -142,8 +143,10 @@ func (s *OldFlagEvaluationService) EventStream( s.logger.Debug(fmt.Sprintf("starting event stream for request")) requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return err + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -185,11 +188,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean( ctx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) - err := resolve[bool]( + err = resolve[bool]( ctx, s.logger, s.eval.ResolveBooleanValue, @@ -217,12 +222,14 @@ func (s *OldFlagEvaluationService) ResolveString( ctx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveStringResponse{}) - err := resolve[string]( + err = resolve[string]( ctx, s.logger, s.eval.ResolveStringValue, @@ -250,12 +257,14 @@ func (s *OldFlagEvaluationService) ResolveInt( ctx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveIntResponse{}) - err := resolve[int64]( + err = resolve[int64]( ctx, s.logger, s.eval.ResolveIntValue, @@ -283,12 +292,14 @@ func (s *OldFlagEvaluationService) ResolveFloat( ctx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveFloatResponse{}) - err := resolve[float64]( + err = resolve[float64]( ctx, s.logger, s.eval.ResolveFloatValue, @@ -316,12 +327,14 @@ func (s *OldFlagEvaluationService) ResolveObject( ctx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveObjectResponse{}) - err := resolve[map[string]any]( + err = resolve[map[string]any]( ctx, s.logger, s.eval.ResolveObjectValue, diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go index de3df1362..d7e45797c 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go @@ -1046,3 +1046,69 @@ func Test_Readable_ErrorMessage(t *testing.T) { }) } } + +func TestInvalidSelector_OldFlagEvaluationService(t *testing.T) { + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewOldFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil) + + assertInvalidSelectorConnect(t, []invalidSelectorCase{ + {"ResolveAll", func() error { + req := connect.NewRequest(&schemaV1.ResolveAllRequest{}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveAll(context.Background(), req) + return err + }}, + {"ResolveBoolean", func() error { + req := connect.NewRequest(&schemaV1.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&schemaV1.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&schemaV1.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&schemaV1.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&schemaV1.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + }) +} + +const invalidSelectorExpr = "invalidKey=val" + +type invalidSelectorCase struct { + name string + call func() error +} + +func assertInvalidSelectorConnect(t *testing.T, cases []invalidSelectorCase) { + t.Helper() + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + err := tt.call() + require.Error(t, err) + var connectErr *connect.Error + require.True(t, errors.As(err, &connectErr)) + require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go index 747a8742b..8182e5ebb 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go @@ -13,7 +13,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -75,8 +74,10 @@ func (s *FlagEvaluationService) ResolveAll( Flags: make(map[string]*evalV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings) ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -165,8 +166,10 @@ func (s *FlagEvaluationService) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return err + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -211,13 +214,15 @@ func (s *FlagEvaluationService) ResolveBoolean( ctx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveBooleanResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveBooleanValue, @@ -244,13 +249,15 @@ func (s *FlagEvaluationService) ResolveString( ctx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveStringResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveStringValue, @@ -277,13 +284,15 @@ func (s *FlagEvaluationService) ResolveInt( ctx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveIntResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveIntValue, @@ -310,13 +319,15 @@ func (s *FlagEvaluationService) ResolveFloat( ctx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveFloatResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveFloatValue, @@ -343,13 +354,15 @@ func (s *FlagEvaluationService) ResolveObject( ctx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return nil, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveObjectResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveObjectValue, diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go index 70dbf3bf6..2c6fd2fa6 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go @@ -1067,3 +1067,49 @@ func Test_mergeContexts(t *testing.T) { }) } } + +func TestInvalidSelector_FlagEvaluationService(t *testing.T) { + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) + + assertInvalidSelectorConnect(t, []invalidSelectorCase{ + {"ResolveAll", func() error { + req := connect.NewRequest(&evalV1.ResolveAllRequest{}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveAll(context.Background(), req) + return err + }}, + {"ResolveBoolean", func() error { + req := connect.NewRequest(&evalV1.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&evalV1.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&evalV1.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&evalV1.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&evalV1.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + }) +} diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index a78650776..dbb43830e 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -15,7 +15,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/codes" @@ -80,8 +79,10 @@ func (s *FlagEvaluationServiceV2) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) + if err != nil { + return err + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -123,11 +124,14 @@ func (s *FlagEvaluationServiceV2) ResolveBoolean( ctx context.Context, req *connect.Request[evalV2.ResolveBooleanRequest], ) (*connect.Response[evalV2.ResolveBooleanResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveBoolean", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveBoolean", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveBooleanResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveBooleanValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &booleanResponseV2{evalV2Resp: res}, @@ -142,11 +146,14 @@ func (s *FlagEvaluationServiceV2) ResolveString( ctx context.Context, req *connect.Request[evalV2.ResolveStringRequest], ) (*connect.Response[evalV2.ResolveStringResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveString", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveString", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveStringResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveStringValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &stringResponseV2{evalV2Resp: res}, @@ -161,11 +168,14 @@ func (s *FlagEvaluationServiceV2) ResolveInt( ctx context.Context, req *connect.Request[evalV2.ResolveIntRequest], ) (*connect.Response[evalV2.ResolveIntResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveInt", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveInt", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveIntResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveIntValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &intResponseV2{evalV2Resp: res}, @@ -180,11 +190,14 @@ func (s *FlagEvaluationServiceV2) ResolveFloat( ctx context.Context, req *connect.Request[evalV2.ResolveFloatRequest], ) (*connect.Response[evalV2.ResolveFloatResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveFloat", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveFloat", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveFloatResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveFloatValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &floatResponseV2{evalV2Resp: res}, @@ -199,11 +212,14 @@ func (s *FlagEvaluationServiceV2) ResolveObject( ctx context.Context, req *connect.Request[evalV2.ResolveObjectRequest], ) (*connect.Response[evalV2.ResolveObjectResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveObject", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveObject", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveObjectResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveObjectValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &objectResponseV2{evalV2Resp: res}, @@ -268,14 +284,18 @@ func resolveV2[T constraints](ctx context.Context, logger *logger.Logger, resolv // startResolveV2 initialises tracing and selector context common to every Resolve* method. func (s *FlagEvaluationServiceV2) startResolveV2( ctx context.Context, spanName string, header http.Header, -) (context.Context, trace.Span) { +) (context.Context, trace.Span, error) { ctx, span := s.flagEvalTracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindServer)) - selectorExpression := header.Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(header) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return ctx, span, err + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) - return ctx, span + return ctx, span, nil } // recordResolveErrorV2 records an evaluation error on the active span. diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go index c64699a62..4097cb7a0 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go @@ -121,3 +121,43 @@ func TestFlagEvaluationServiceV2_Fallback(t *testing.T) { }) } } + +func TestInvalidSelector_FlagEvaluationServiceV2(t *testing.T) { + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewFlagEvaluationServiceV2(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) + + assertInvalidSelectorConnect(t, []invalidSelectorCase{ + {"ResolveBoolean", func() error { + req := connect.NewRequest(&evalV2.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&evalV2.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&evalV2.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&evalV2.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&evalV2.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + }) +} diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler.go b/flagd/pkg/service/flag-evaluation/ofrep/handler.go index ac3f2ddf8..bfeb9ea5a 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler.go @@ -97,7 +97,15 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) { } evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + h.writeJSONToResponse(http.StatusBadRequest, ofrep.EvaluationError{ + Key: flagKey, + ErrorCode: model.GeneralErrorCode, + ErrorDetails: fmt.Sprintf("invalid selector: %v", err), + }, w) + return + } ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) evaluation := h.evaluator.ResolveAsAnyValue(ctx, requestID, flagKey, evaluationContext) @@ -122,7 +130,12 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) { evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + res := ofrep.BulkEvaluationContextErrorFrom(model.GeneralErrorCode, fmt.Sprintf("invalid selector: %v", err)) + h.writeJSONToResponse(http.StatusBadRequest, res, w) + return + } ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) evaluations, metadata, err := h.evaluator.ResolveAllValues(ctx, requestID, evaluationContext) diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go index 3ae4d4635..b83756820 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go @@ -16,6 +16,7 @@ import ( "github.com/open-feature/flagd/core/pkg/logger" "github.com/open-feature/flagd/core/pkg/model" "github.com/open-feature/flagd/core/pkg/service/ofrep" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -374,3 +375,38 @@ func TestFlagdContextDelegatesContextMerging(t *testing.T) { } } } + +func TestInvalidSelector_OFREPHandler(t *testing.T) { + const invalidSelector = "invalidKey=val" + log := logger.NewLogger(nil, false) + + t.Run("HandleFlagEvaluation", func(t *testing.T) { + eval := mock.NewMockIEvaluator(gomock.NewController(t)) + h := handler{Logger: log, evaluator: eval} + + req, _ := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags/"+flagKey, bytes.NewReader([]byte{})) + req.Header.Set("Flagd-Selector", invalidSelector) + + recorder := httptest.NewRecorder() + router := mux.NewRouter() + router.HandleFunc(singleEvaluation, h.HandleFlagEvaluation) + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + }) + + t.Run("HandleBulkEvaluation", func(t *testing.T) { + eval := mock.NewMockIEvaluator(gomock.NewController(t)) + h := handler{Logger: log, evaluator: eval} + + req, _ := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags", bytes.NewReader([]byte{})) + req.Header.Set("Flagd-Selector", invalidSelector) + + recorder := httptest.NewRecorder() + router := mux.NewRouter() + router.HandleFunc(bulkEvaluation, h.HandleBulkEvaluation) + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + }) +} diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index 559a33f44..d23d63903 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -62,7 +62,11 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F }() watcher := make(chan store.FlagQueryResult, 1) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + exitReason = "error" + return status.Error(codes.InvalidArgument, err.Error()) + } ctx := server.Context() syncContextMap := make(map[string]any) @@ -166,7 +170,10 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag *syncv1.FetchAllFlagsResponse, error, ) { selectorExpression := s.getSelectorExpression(ctx, req) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } flags, _, err := s.store.GetAll(ctx, &selector) if err != nil { s.log.Error(fmt.Sprintf("error retrieving flags from store: %v", err)) diff --git a/flagd/pkg/service/flag-sync/handler_test.go b/flagd/pkg/service/flag-sync/handler_test.go index 7621be845..b869e409b 100644 --- a/flagd/pkg/service/flag-sync/handler_test.go +++ b/flagd/pkg/service/flag-sync/handler_test.go @@ -258,3 +258,35 @@ func TestSyncHandler_SelectorLocationPrecedence(t *testing.T) { }) } } + +func TestSyncHandler_InvalidSelector(t *testing.T) { + const invalidSelector = "invalidKey=val" + + flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{}) + require.NoError(t, err) + + h := syncHandler{ + store: flagStore, + log: logger.NewLogger(nil, false), + contextValues: map[string]any{}, + metricsRecorder: &telemetry.NoopMetricsRecorder{}, + } + + ctxWithInvalidSelector := metadata.NewIncomingContext( + context.Background(), + metadata.New(map[string]string{flagdService.FLAGD_SELECTOR_HEADER: invalidSelector}), + ) + + t.Run("SyncFlags", func(t *testing.T) { + stream := &mockSyncFlagsServer{ctx: ctxWithInvalidSelector, respReady: make(chan struct{}, 1)} + err := h.SyncFlags(&syncv1.SyncFlagsRequest{}, stream) + require.Error(t, err) + require.Equal(t, "rpc error: code = InvalidArgument desc = invalid selector key \"invalidKey\", valid keys: \"flagSetId\", \"source\"", err.Error()) + }) + + t.Run("FetchAllFlags", func(t *testing.T) { + _, err := h.FetchAllFlags(ctxWithInvalidSelector, &syncv1.FetchAllFlagsRequest{}) + require.Error(t, err) + require.Equal(t, "rpc error: code = InvalidArgument desc = invalid selector key \"invalidKey\", valid keys: \"flagSetId\", \"source\"", err.Error()) + }) +}