From 0519beb1d784c9f381dcbc744bd57fc25618d316 Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 10:32:13 +0200 Subject: [PATCH 1/6] Prevent usage of unsupported selectors to avoid memdb errors Signed-off-by: Alexandre Chakroun --- core/pkg/store/query.go | 41 +++++++++++++--- core/pkg/store/query_test.go | 25 ++++++++-- core/pkg/store/store.go | 6 +-- core/pkg/store/store_test.go | 24 ++++----- .../service/flag-evaluation/flag_evaluator.go | 35 ++++++++++--- .../flag-evaluation/flag_evaluator_v1.go | 35 ++++++++++--- .../flag-evaluation/flag_evaluator_v2.go | 49 +++++++++++++------ .../service/flag-evaluation/ofrep/handler.go | 17 ++++++- flagd/pkg/service/flag-sync/handler.go | 10 +++- 9 files changed, 184 insertions(+), 58 deletions(-) diff --git a/core/pkg/store/query.go b/core/pkg/store/query.go index 478440a8b..8f3d5d781 100644 --- a/core/pkg/store/query.go +++ b/core/pkg/store/query.go @@ -33,14 +33,33 @@ type Selector struct { indexMap map[string]string } +var validSelectorKeys = map[string]struct{}{ + flagSetIdIndex: {}, + sourceIndex: {}, +} + // NewSelector creates a new Selector from a selector expression string. // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported // For example, to select flags from source "./mySource" or flagSetId "1234", use the expressions: // "source=./mySource" or "flagSetId=1234" -func NewSelector(selectorExpression string) Selector { - return Selector{ - indexMap: expressionToMap(selectorExpression), +func NewSelector(selectorExpression string) (Selector, error) { + m := expressionToMap(selectorExpression) + for key := range m { + if _, ok := validSelectorKeys[key]; !ok { + return Selector{}, fmt.Errorf("invalid selector key %q, valid keys: %q, %q", key, flagSetIdIndex, sourceIndex) + } } + return Selector{indexMap: m}, nil +} + +// NewSourceSelector creates a Selector that queries by source. +func NewSourceSelector(source string) Selector { + return Selector{indexMap: map[string]string{sourceIndex: source}} +} + +// NewFlagSetIdSelector creates a Selector that queries by flagSetId. +func NewFlagSetIdSelector(flagSetId string) Selector { + return Selector{indexMap: map[string]string{flagSetIdIndex: flagSetId}} } func expressionToMap(sExp string) map[string]string { @@ -70,13 +89,19 @@ func expressionToMap(sExp string) map[string]string { return selectorMap } -// WithIndex creates a new Selector from the current Selector and adds the given key-value-pair -func (s Selector) WithIndex(key string, value string) Selector { +// WithIndex returns a new Selector with the given key-value pair added. +// Returns an error if the key is not a valid user-facing selector key. +func (s Selector) WithIndex(key string, value string) (Selector, error) { + if _, ok := validSelectorKeys[key]; !ok { + return s, fmt.Errorf("invalid selector key %q, valid keys: %q, %q", key, flagSetIdIndex, sourceIndex) + } + return s.withIndex(key, value), nil +} + +func (s Selector) withIndex(key string, value string) Selector { m := maps.Clone(s.indexMap) m[key] = value - return Selector{ - indexMap: m, - } + return Selector{indexMap: m} } func (s *Selector) IsEmpty() bool { diff --git a/core/pkg/store/query_test.go b/core/pkg/store/query_test.go index 042b3248b..f141b172c 100644 --- a/core/pkg/store/query_test.go +++ b/core/pkg/store/query_test.go @@ -52,7 +52,10 @@ func TestSelector_IsEmpty(t *testing.T) { func TestSelector_WithIndex(t *testing.T) { oldS := Selector{indexMap: map[string]string{"source": "abc"}} - newS := oldS.WithIndex("flagSetId", "1234") + newS, err := oldS.WithIndex("flagSetId", "1234") + if err != nil { + t.Fatalf("WithIndex returned unexpected error: %v", err) + } if newS.indexMap["source"] != "abc" { t.Errorf("WithIndex did not preserve existing keys") @@ -64,8 +67,14 @@ func TestSelector_WithIndex(t *testing.T) { if _, ok := oldS.indexMap["flagSetId"]; ok { t.Errorf("WithIndex mutated original selector") } + + _, err = oldS.WithIndex("invalidKey", "val") + if err == nil { + t.Errorf("WithIndex should return error for invalid key") + } } + func TestSelector_ToQuery(t *testing.T) { tests := []struct { name string @@ -176,6 +185,7 @@ func TestNewSelector(t *testing.T) { name string input string wantMap map[string]string + wantErr bool }{ // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported /* @@ -205,12 +215,21 @@ func TestNewSelector(t *testing.T) { input: "", wantMap: map[string]string{}, }, + { + name: "invalid key", + input: "flagSetIds=abc", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewSelector(tt.input) - if !reflect.DeepEqual(s.indexMap, tt.wantMap) { + s, err := NewSelector(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("NewSelector(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if !tt.wantErr && !reflect.DeepEqual(s.indexMap, tt.wantMap) { t.Errorf("NewSelector(%q) indexMap = %v, want %v", tt.input, s.indexMap, tt.wantMap) } }) diff --git a/core/pkg/store/store.go b/core/pkg/store/store.go index 48b80b107..e28fc1ef6 100644 --- a/core/pkg/store/store.go +++ b/core/pkg/store/store.go @@ -153,7 +153,7 @@ func (s *Store) Get(_ context.Context, key string, selector *Selector) (model.Fl // if present, use the selector to query the flags if !selector.IsEmpty() { - selector := selector.WithIndex("key", key) + selector := selector.withIndex("key", key) indexId, constraints := selector.ToQuery() s.logger.Debug(fmt.Sprintf("getting flag with query: %s, %v", indexId, constraints)) raw, err := txn.First(flagsTable, indexId, constraints...) @@ -274,7 +274,7 @@ func (s *Store) Update( seenFlagSetIds[id.flagSetId] = struct{}{} } for fsi := range seenFlagSetIds { - sel := NewSelector(flagSetIdIndex+"="+fsi).WithIndex(sourceIndex, source) + sel := NewFlagSetIdSelector(fsi).withIndex(sourceIndex, source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { @@ -284,7 +284,7 @@ func (s *Store) Update( oldFlags = append(oldFlags, s.collect(it)...) } } else { - sel := NewSelector(sourceIndex + "=" + source) + sel := NewSourceSelector(source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { diff --git a/core/pkg/store/store_test.go b/core/pkg/store/store_test.go index a729c4e49..bd8631e0f 100644 --- a/core/pkg/store/store_test.go +++ b/core/pkg/store/store_test.go @@ -238,8 +238,8 @@ func TestGet(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSelector("source=" + sourceA.Name) - flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC) + sourceASelector := NewSourceSelector(sourceA.Name) + flagSetIdCSelector := NewFlagSetIdSelector(flagSetIdC) t.Parallel() tests := []struct { @@ -377,10 +377,10 @@ func TestGetAllNoWatcher(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSelector("source=" + sourceA.Name) - flagSetIdCSelector := NewSelector("flagSetId=" + flagSetIdC) + sourceASelector := NewSourceSelector(sourceA.Name) + flagSetIdCSelector := NewFlagSetIdSelector(flagSetIdC) // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported - //flagSetIdAndCSelector := NewSelector("flagSetId=" + flagSetIdC + ",source=" + sourceC.Name) + //flagSetIdAndCSelector := NewFlagSetIdSelector(flagSetIdC).withIndex(sourceIndex, sourceC.Name) t.Parallel() tests := []struct { @@ -506,10 +506,10 @@ func TestWatch(t *testing.T) { pauseTime := 100 * time.Millisecond // time for updates to settle timeout := 1000 * time.Millisecond // time to make sure we get enough updates, and no extras - sourceASelector := NewSelector("source=" + sourceA) - flagSetIdCSelector := NewSelector("flagSetId=" + myFlagSetId) - emptySelector := NewSelector("") - sourceCSelector := NewSelector("source=" + sourceC) + sourceASelector := NewSourceSelector(sourceA) + flagSetIdCSelector := NewFlagSetIdSelector(myFlagSetId) + emptySelector := Selector{} + sourceCSelector := NewSourceSelector(sourceC) tests := []struct { name string @@ -786,15 +786,15 @@ func TestQueryMetadata(t *testing.T) { // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported // these tests should then also cover more complex selectors - selector := NewSelector("flagSetId=" + nonExistingFlagSetId) + selector := NewFlagSetIdSelector(nonExistingFlagSetId) _, metadata, _ := store.GetAll(context.Background(), &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewSelector("flagSetId=" + nonExistingFlagSetId) + selector = NewFlagSetIdSelector(nonExistingFlagSetId) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewSelector("source=" + otherSource) + selector = NewSourceSelector(otherSource) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"source": otherSource}, "metadata did not match expected") } diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index ff376ebfb..7f4dbf69a 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -76,7 +76,10 @@ func (s *OldFlagEvaluationService) ResolveAll( } selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string))) @@ -143,7 +146,10 @@ func (s *OldFlagEvaluationService) EventStream( requestNotificationChan := make(chan service.Notification, 1) selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return connect.NewError(connect.CodeInvalidArgument, err) + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -186,7 +192,10 @@ func (s *OldFlagEvaluationService) ResolveBoolean( defer span.End() res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) err := resolve[bool]( @@ -218,7 +227,10 @@ func (s *OldFlagEvaluationService) ResolveString( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveStringResponse{}) @@ -251,7 +263,10 @@ func (s *OldFlagEvaluationService) ResolveInt( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveIntResponse{}) @@ -284,7 +299,10 @@ func (s *OldFlagEvaluationService) ResolveFloat( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveFloatResponse{}) @@ -317,7 +335,10 @@ func (s *OldFlagEvaluationService) ResolveObject( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveObjectResponse{}) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go index 747a8742b..496122d62 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go @@ -76,7 +76,10 @@ func (s *FlagEvaluationService) ResolveAll( } selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings) ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -166,7 +169,10 @@ func (s *FlagEvaluationService) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return connect.NewError(connect.CodeInvalidArgument, err) + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -212,7 +218,10 @@ func (s *FlagEvaluationService) ResolveBoolean( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -245,7 +254,10 @@ func (s *FlagEvaluationService) ResolveString( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -278,7 +290,10 @@ func (s *FlagEvaluationService) ResolveInt( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -311,7 +326,10 @@ func (s *FlagEvaluationService) ResolveFloat( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") @@ -344,7 +362,10 @@ func (s *FlagEvaluationService) ResolveObject( defer span.End() selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index a78650776..d92584ff4 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -81,7 +81,10 @@ func (s *FlagEvaluationServiceV2) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return connect.NewError(connect.CodeInvalidArgument, err) + } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -123,11 +126,14 @@ func (s *FlagEvaluationServiceV2) ResolveBoolean( ctx context.Context, req *connect.Request[evalV2.ResolveBooleanRequest], ) (*connect.Response[evalV2.ResolveBooleanResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveBoolean", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveBoolean", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveBooleanResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveBooleanValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &booleanResponseV2{evalV2Resp: res}, @@ -142,11 +148,14 @@ func (s *FlagEvaluationServiceV2) ResolveString( ctx context.Context, req *connect.Request[evalV2.ResolveStringRequest], ) (*connect.Response[evalV2.ResolveStringResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveString", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveString", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveStringResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveStringValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &stringResponseV2{evalV2Resp: res}, @@ -161,11 +170,14 @@ func (s *FlagEvaluationServiceV2) ResolveInt( ctx context.Context, req *connect.Request[evalV2.ResolveIntRequest], ) (*connect.Response[evalV2.ResolveIntResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveInt", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveInt", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveIntResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveIntValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &intResponseV2{evalV2Resp: res}, @@ -180,11 +192,14 @@ func (s *FlagEvaluationServiceV2) ResolveFloat( ctx context.Context, req *connect.Request[evalV2.ResolveFloatRequest], ) (*connect.Response[evalV2.ResolveFloatResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveFloat", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveFloat", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveFloatResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveFloatValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &floatResponseV2{evalV2Resp: res}, @@ -199,11 +214,14 @@ func (s *FlagEvaluationServiceV2) ResolveObject( ctx context.Context, req *connect.Request[evalV2.ResolveObjectRequest], ) (*connect.Response[evalV2.ResolveObjectResponse], error) { - ctx, span := s.startResolveV2(ctx, "resolveObject", req.Header()) + ctx, span, err := s.startResolveV2(ctx, "resolveObject", req.Header()) defer span.End() + if err != nil { + return nil, err + } res := connect.NewResponse(&evalV2.ResolveObjectResponse{}) - err := resolveV2( + err = resolveV2( ctx, s.logger, s.eval.ResolveObjectValue, req.Header(), req.Msg.GetFlagKey(), req.Msg.GetContext(), &objectResponseV2{evalV2Resp: res}, @@ -268,14 +286,17 @@ func resolveV2[T constraints](ctx context.Context, logger *logger.Logger, resolv // startResolveV2 initialises tracing and selector context common to every Resolve* method. func (s *FlagEvaluationServiceV2) startResolveV2( ctx context.Context, spanName string, header http.Header, -) (context.Context, trace.Span) { +) (context.Context, trace.Span, error) { ctx, span := s.flagEvalTracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindServer)) selectorExpression := header.Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return ctx, span, connect.NewError(connect.CodeInvalidArgument, err) + } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) - return ctx, span + return ctx, span, nil } // recordResolveErrorV2 records an evaluation error on the active span. diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler.go b/flagd/pkg/service/flag-evaluation/ofrep/handler.go index ac3f2ddf8..bfeb9ea5a 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler.go @@ -97,7 +97,15 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) { } evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + h.writeJSONToResponse(http.StatusBadRequest, ofrep.EvaluationError{ + Key: flagKey, + ErrorCode: model.GeneralErrorCode, + ErrorDetails: fmt.Sprintf("invalid selector: %v", err), + }, w) + return + } ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) evaluation := h.evaluator.ResolveAsAnyValue(ctx, requestID, flagKey, evaluationContext) @@ -122,7 +130,12 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) { evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + res := ofrep.BulkEvaluationContextErrorFrom(model.GeneralErrorCode, fmt.Sprintf("invalid selector: %v", err)) + h.writeJSONToResponse(http.StatusBadRequest, res, w) + return + } ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) evaluations, metadata, err := h.evaluator.ResolveAllValues(ctx, requestID, evaluationContext) diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index 559a33f44..59d2389a8 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -62,7 +62,10 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F }() watcher := make(chan store.FlagQueryResult, 1) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return status.Error(codes.InvalidArgument, err.Error()) + } ctx := server.Context() syncContextMap := make(map[string]any) @@ -166,7 +169,10 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag *syncv1.FetchAllFlagsResponse, error, ) { selectorExpression := s.getSelectorExpression(ctx, req) - selector := store.NewSelector(selectorExpression) + selector, err := store.NewSelector(selectorExpression) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } flags, _, err := s.store.GetAll(ctx, &selector) if err != nil { s.log.Error(fmt.Sprintf("error retrieving flags from store: %v", err)) From a0011feff6d2df7d34351d1140656a1616cd5a21 Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 10:42:43 +0200 Subject: [PATCH 2/6] Review query.go API Signed-off-by: Alexandre Chakroun --- core/pkg/store/query.go | 26 +++++++------------------- core/pkg/store/query_test.go | 28 ++++++++++++---------------- core/pkg/store/store.go | 6 +++--- core/pkg/store/store_test.go | 22 +++++++++++----------- 4 files changed, 33 insertions(+), 49 deletions(-) diff --git a/core/pkg/store/query.go b/core/pkg/store/query.go index 8f3d5d781..e89c52250 100644 --- a/core/pkg/store/query.go +++ b/core/pkg/store/query.go @@ -52,16 +52,6 @@ func NewSelector(selectorExpression string) (Selector, error) { return Selector{indexMap: m}, nil } -// NewSourceSelector creates a Selector that queries by source. -func NewSourceSelector(source string) Selector { - return Selector{indexMap: map[string]string{sourceIndex: source}} -} - -// NewFlagSetIdSelector creates a Selector that queries by flagSetId. -func NewFlagSetIdSelector(flagSetId string) Selector { - return Selector{indexMap: map[string]string{flagSetIdIndex: flagSetId}} -} - func expressionToMap(sExp string) map[string]string { selectorMap := make(map[string]string) if sExp == "" { @@ -89,17 +79,15 @@ func expressionToMap(sExp string) map[string]string { return selectorMap } -// WithIndex returns a new Selector with the given key-value pair added. -// Returns an error if the key is not a valid user-facing selector key. -func (s Selector) WithIndex(key string, value string) (Selector, error) { - if _, ok := validSelectorKeys[key]; !ok { - return s, fmt.Errorf("invalid selector key %q, valid keys: %q, %q", key, flagSetIdIndex, sourceIndex) - } - return s.withIndex(key, value), nil -} +func (s Selector) WithSource(source string) Selector { return s.withIndex(sourceIndex, source) } +func (s Selector) WithFlagSetId(id string) Selector { return s.withIndex(flagSetIdIndex, id) } +func (s Selector) withKey(key string) Selector { return s.withIndex(keyIndex, key) } -func (s Selector) withIndex(key string, value string) Selector { +func (s Selector) withIndex(key, value string) Selector { m := maps.Clone(s.indexMap) + if m == nil { + m = make(map[string]string, 1) + } m[key] = value return Selector{indexMap: m} } diff --git a/core/pkg/store/query_test.go b/core/pkg/store/query_test.go index f141b172c..76d3a7145 100644 --- a/core/pkg/store/query_test.go +++ b/core/pkg/store/query_test.go @@ -50,27 +50,23 @@ func TestSelector_IsEmpty(t *testing.T) { } } -func TestSelector_WithIndex(t *testing.T) { - oldS := Selector{indexMap: map[string]string{"source": "abc"}} - newS, err := oldS.WithIndex("flagSetId", "1234") - if err != nil { - t.Fatalf("WithIndex returned unexpected error: %v", err) +func TestSelector_WithSourceAndFlagSetId(t *testing.T) { + s := Selector{}.WithSource("abc") + if s.indexMap[sourceIndex] != "abc" { + t.Errorf("WithSource did not set source") } - if newS.indexMap["source"] != "abc" { - t.Errorf("WithIndex did not preserve existing keys") + s2 := s.WithFlagSetId("1234") + if s2.indexMap[sourceIndex] != "abc" { + t.Errorf("WithFlagSetId did not preserve source") } - if newS.indexMap["flagSetId"] != "1234" { - t.Errorf("WithIndex did not add new key") - } - // Ensure original is unchanged - if _, ok := oldS.indexMap["flagSetId"]; ok { - t.Errorf("WithIndex mutated original selector") + if s2.indexMap[flagSetIdIndex] != "1234" { + t.Errorf("WithFlagSetId did not set flagSetId") } - _, err = oldS.WithIndex("invalidKey", "val") - if err == nil { - t.Errorf("WithIndex should return error for invalid key") + // Ensure original is unchanged + if _, ok := s.indexMap[flagSetIdIndex]; ok { + t.Errorf("WithFlagSetId mutated original selector") } } diff --git a/core/pkg/store/store.go b/core/pkg/store/store.go index e28fc1ef6..5cc53c10f 100644 --- a/core/pkg/store/store.go +++ b/core/pkg/store/store.go @@ -153,7 +153,7 @@ func (s *Store) Get(_ context.Context, key string, selector *Selector) (model.Fl // if present, use the selector to query the flags if !selector.IsEmpty() { - selector := selector.withIndex("key", key) + selector := selector.withKey(key) indexId, constraints := selector.ToQuery() s.logger.Debug(fmt.Sprintf("getting flag with query: %s, %v", indexId, constraints)) raw, err := txn.First(flagsTable, indexId, constraints...) @@ -274,7 +274,7 @@ func (s *Store) Update( seenFlagSetIds[id.flagSetId] = struct{}{} } for fsi := range seenFlagSetIds { - sel := NewFlagSetIdSelector(fsi).withIndex(sourceIndex, source) + sel := Selector{}.WithFlagSetId(fsi).WithSource(source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { @@ -284,7 +284,7 @@ func (s *Store) Update( oldFlags = append(oldFlags, s.collect(it)...) } } else { - sel := NewSourceSelector(source) + sel := Selector{}.WithSource(source) indexId, constraints := sel.ToQuery() it, err := txn.Get(flagsTable, indexId, constraints...) if err != nil { diff --git a/core/pkg/store/store_test.go b/core/pkg/store/store_test.go index bd8631e0f..693c021f6 100644 --- a/core/pkg/store/store_test.go +++ b/core/pkg/store/store_test.go @@ -238,8 +238,8 @@ func TestGet(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSourceSelector(sourceA.Name) - flagSetIdCSelector := NewFlagSetIdSelector(flagSetIdC) + sourceASelector := Selector{}.WithSource(sourceA.Name) + flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC) t.Parallel() tests := []struct { @@ -377,10 +377,10 @@ func TestGetAllNoWatcher(t *testing.T) { sources := []string{sourceA.Name, sourceB.Name, sourceC.Name} - sourceASelector := NewSourceSelector(sourceA.Name) - flagSetIdCSelector := NewFlagSetIdSelector(flagSetIdC) + sourceASelector := Selector{}.WithSource(sourceA.Name) + flagSetIdCSelector := Selector{}.WithFlagSetId(flagSetIdC) // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported - //flagSetIdAndCSelector := NewFlagSetIdSelector(flagSetIdC).withIndex(sourceIndex, sourceC.Name) + //flagSetIdAndCSelector := Selector{}.WithFlagSetId(flagSetIdC).withIndex(sourceIndex, sourceC.Name) t.Parallel() tests := []struct { @@ -506,10 +506,10 @@ func TestWatch(t *testing.T) { pauseTime := 100 * time.Millisecond // time for updates to settle timeout := 1000 * time.Millisecond // time to make sure we get enough updates, and no extras - sourceASelector := NewSourceSelector(sourceA) - flagSetIdCSelector := NewFlagSetIdSelector(myFlagSetId) + sourceASelector := Selector{}.WithSource(sourceA) + flagSetIdCSelector := Selector{}.WithFlagSetId(myFlagSetId) emptySelector := Selector{} - sourceCSelector := NewSourceSelector(sourceC) + sourceCSelector := Selector{}.WithSource(sourceC) tests := []struct { name string @@ -786,15 +786,15 @@ func TestQueryMetadata(t *testing.T) { // #1708 Until we decide on the Selector syntax, only a single key=value pair is supported // these tests should then also cover more complex selectors - selector := NewFlagSetIdSelector(nonExistingFlagSetId) + selector := Selector{}.WithFlagSetId(nonExistingFlagSetId) _, metadata, _ := store.GetAll(context.Background(), &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewFlagSetIdSelector(nonExistingFlagSetId) + selector = Selector{}.WithFlagSetId(nonExistingFlagSetId) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"flagSetId": nonExistingFlagSetId}, "metadata did not match expected") - selector = NewSourceSelector(otherSource) + selector = Selector{}.WithSource(otherSource) _, metadata, _ = store.Get(context.Background(), "key", &selector) assert.Equal(t, metadata, model.Metadata{"source": otherSource}, "metadata did not match expected") } From 0abea5d14e74a5a8fde68ed673b041ab92801ba0 Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 13:50:04 +0200 Subject: [PATCH 3/6] sonar comments Signed-off-by: Alexandre Chakroun --- .../service/flag-evaluation/context_utils.go | 13 ++++++ .../service/flag-evaluation/flag_evaluator.go | 46 ++++++++----------- .../flag-evaluation/flag_evaluator_v1.go | 46 ++++++++----------- .../flag-evaluation/flag_evaluator_v2.go | 11 ++--- 4 files changed, 55 insertions(+), 61 deletions(-) diff --git a/flagd/pkg/service/flag-evaluation/context_utils.go b/flagd/pkg/service/flag-evaluation/context_utils.go index 72d7c00ad..710fab55d 100644 --- a/flagd/pkg/service/flag-evaluation/context_utils.go +++ b/flagd/pkg/service/flag-evaluation/context_utils.go @@ -2,8 +2,21 @@ package service import ( "net/http" + + "connectrpc.com/connect" + flagdService "github.com/open-feature/flagd/flagd/pkg/service" + "github.com/open-feature/flagd/core/pkg/store" ) +func selectorFromHeader(header http.Header) (store.Selector, error) { + expr := header.Get(flagdService.FLAGD_SELECTOR_HEADER) + s, err := store.NewSelector(expr) + if err != nil { + return store.Selector{}, connect.NewError(connect.CodeInvalidArgument, err) + } + return s, nil +} + // MergeContextsAndHeaders merges evaluation contexts with static context values and header-based context. // highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority // Header names are matched case-insensitively according to HTTP specification. diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index 7f4dbf69a..3826c2f34 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -14,7 +14,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -75,10 +74,9 @@ func (s *OldFlagEvaluationService) ResolveAll( Flags: make(map[string]*schemaV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) @@ -145,10 +143,9 @@ func (s *OldFlagEvaluationService) EventStream( s.logger.Debug(fmt.Sprintf("starting event stream for request")) requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return connect.NewError(connect.CodeInvalidArgument, err) + return err } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -191,14 +188,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean( ctx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) - err := resolve[bool]( + err = resolve[bool]( ctx, s.logger, s.eval.ResolveBooleanValue, @@ -226,15 +222,14 @@ func (s *OldFlagEvaluationService) ResolveString( ctx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveStringResponse{}) - err := resolve[string]( + err = resolve[string]( ctx, s.logger, s.eval.ResolveStringValue, @@ -262,15 +257,14 @@ func (s *OldFlagEvaluationService) ResolveInt( ctx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveIntResponse{}) - err := resolve[int64]( + err = resolve[int64]( ctx, s.logger, s.eval.ResolveIntValue, @@ -298,15 +292,14 @@ func (s *OldFlagEvaluationService) ResolveFloat( ctx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveFloatResponse{}) - err := resolve[float64]( + err = resolve[float64]( ctx, s.logger, s.eval.ResolveFloatValue, @@ -334,15 +327,14 @@ func (s *OldFlagEvaluationService) ResolveObject( ctx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) res := connect.NewResponse(&schemaV1.ResolveObjectResponse{}) - err := resolve[map[string]any]( + err = resolve[map[string]any]( ctx, s.logger, s.eval.ResolveObjectValue, diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go index 496122d62..8182e5ebb 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go @@ -13,7 +13,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -75,10 +74,9 @@ func (s *FlagEvaluationService) ResolveAll( Flags: make(map[string]*evalV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings) ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) @@ -168,10 +166,9 @@ func (s *FlagEvaluationService) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return connect.NewError(connect.CodeInvalidArgument, err) + return err } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -217,16 +214,15 @@ func (s *FlagEvaluationService) ResolveBoolean( ctx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveBooleanResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveBooleanValue, @@ -253,16 +249,15 @@ func (s *FlagEvaluationService) ResolveString( ctx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveStringResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveStringValue, @@ -289,16 +284,15 @@ func (s *FlagEvaluationService) ResolveInt( ctx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveIntResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveIntValue, @@ -325,16 +319,15 @@ func (s *FlagEvaluationService) ResolveFloat( ctx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveFloatResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveFloatValue, @@ -361,16 +354,15 @@ func (s *FlagEvaluationService) ResolveObject( ctx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return nil, connect.NewError(connect.CodeInvalidArgument, err) + return nil, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") res := connect.NewResponse(&evalV1.ResolveObjectResponse{}) - err := resolve( + err = resolve( ctx, s.logger, s.eval.ResolveObjectValue, diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index d92584ff4..f66711a63 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -15,7 +15,6 @@ import ( "github.com/open-feature/flagd/core/pkg/service" "github.com/open-feature/flagd/core/pkg/store" "github.com/open-feature/flagd/core/pkg/telemetry" - flagdService "github.com/open-feature/flagd/flagd/pkg/service" "github.com/rs/xid" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/codes" @@ -80,10 +79,9 @@ func (s *FlagEvaluationServiceV2) EventStream( s.logger.Debug("starting event stream for request") requestNotificationChan := make(chan service.Notification, 1) - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(req.Header()) if err != nil { - return connect.NewError(connect.CodeInvalidArgument, err) + return err } s.eventingConfiguration.Subscribe(ctx, req, &selector, requestNotificationChan) defer s.eventingConfiguration.Unsubscribe(req) @@ -289,10 +287,9 @@ func (s *FlagEvaluationServiceV2) startResolveV2( ) (context.Context, trace.Span, error) { ctx, span := s.flagEvalTracer.Start(ctx, spanName, trace.WithSpanKind(trace.SpanKindServer)) - selectorExpression := header.Get(flagdService.FLAGD_SELECTOR_HEADER) - selector, err := store.NewSelector(selectorExpression) + selector, err := selectorFromHeader(header) if err != nil { - return ctx, span, connect.NewError(connect.CodeInvalidArgument, err) + return ctx, span, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) From 9cf88be0383c0a06a43c01c6715e068dfafc9120 Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 21:18:16 +0200 Subject: [PATCH 4/6] Correctly set context before returning early Signed-off-by: Alexandre Chakroun --- flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go | 2 ++ flagd/pkg/service/flag-sync/handler.go | 1 + 2 files changed, 3 insertions(+) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index f66711a63..dbb43830e 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -289,6 +289,8 @@ func (s *FlagEvaluationServiceV2) startResolveV2( selector, err := selectorFromHeader(header) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return ctx, span, err } ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index 59d2389a8..d23d63903 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -64,6 +64,7 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F watcher := make(chan store.FlagQueryResult, 1) selector, err := store.NewSelector(selectorExpression) if err != nil { + exitReason = "error" return status.Error(codes.InvalidArgument, err.Error()) } ctx := server.Context() From b353b2e2e725be41ad9091170edaa76fc382ec5b Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 23:32:54 +0200 Subject: [PATCH 5/6] Add integration tests Signed-off-by: Alexandre Chakroun --- .../flag-evaluation/flag_evaluator_test.go | 60 +++++++++++++++++++ .../flag-evaluation/flag_evaluator_v1_test.go | 60 +++++++++++++++++++ .../flag-evaluation/flag_evaluator_v2_test.go | 55 +++++++++++++++++ .../flag-evaluation/ofrep/handler_test.go | 36 +++++++++++ flagd/pkg/service/flag-sync/handler_test.go | 32 ++++++++++ 5 files changed, 243 insertions(+) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go index de3df1362..6ad6bd3d6 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go @@ -1046,3 +1046,63 @@ func Test_Readable_ErrorMessage(t *testing.T) { }) } } + +func TestInvalidSelector_OldFlagEvaluationService(t *testing.T) { + const invalidSelector = "invalidKey=val" + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewOldFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil) + + tests := []struct { + name string + call func() error + }{ + {"ResolveAll", func() error { + req := connect.NewRequest(&schemaV1.ResolveAllRequest{}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveAll(context.Background(), req) + return err + }}, + {"ResolveBoolean", func() error { + req := connect.NewRequest(&schemaV1.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&schemaV1.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&schemaV1.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&schemaV1.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&schemaV1.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.call() + require.Error(t, err) + var connectErr *connect.Error + require.True(t, errors.As(err, &connectErr)) + require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go index 70dbf3bf6..24a15d181 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go @@ -1067,3 +1067,63 @@ func Test_mergeContexts(t *testing.T) { }) } } + +func TestInvalidSelector_FlagEvaluationService(t *testing.T) { + const invalidSelector = "invalidKey=val" + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) + + tests := []struct { + name string + call func() error + }{ + {"ResolveAll", func() error { + req := connect.NewRequest(&evalV1.ResolveAllRequest{}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveAll(context.Background(), req) + return err + }}, + {"ResolveBoolean", func() error { + req := connect.NewRequest(&evalV1.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&evalV1.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&evalV1.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&evalV1.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&evalV1.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.call() + require.Error(t, err) + var connectErr *connect.Error + require.True(t, errors.As(err, &connectErr)) + require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go index c64699a62..f88638b3e 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "testing" evalV2 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v2" @@ -121,3 +122,57 @@ func TestFlagEvaluationServiceV2_Fallback(t *testing.T) { }) } } + +func TestInvalidSelector_FlagEvaluationServiceV2(t *testing.T) { + const invalidSelector = "invalidKey=val" + ctrl := gomock.NewController(t) + eval := mock.NewMockIEvaluator(ctrl) + metrics, _ := getMetricReader() + s := NewFlagEvaluationServiceV2(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) + + tests := []struct { + name string + call func() error + }{ + {"ResolveBoolean", func() error { + req := connect.NewRequest(&evalV2.ResolveBooleanRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveBoolean(context.Background(), req) + return err + }}, + {"ResolveString", func() error { + req := connect.NewRequest(&evalV2.ResolveStringRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveString(context.Background(), req) + return err + }}, + {"ResolveInt", func() error { + req := connect.NewRequest(&evalV2.ResolveIntRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveInt(context.Background(), req) + return err + }}, + {"ResolveFloat", func() error { + req := connect.NewRequest(&evalV2.ResolveFloatRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveFloat(context.Background(), req) + return err + }}, + {"ResolveObject", func() error { + req := connect.NewRequest(&evalV2.ResolveObjectRequest{FlagKey: "f"}) + req.Header().Set("Flagd-Selector", invalidSelector) + _, err := s.ResolveObject(context.Background(), req) + return err + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.call() + require.Error(t, err) + var connectErr *connect.Error + require.True(t, errors.As(err, &connectErr)) + require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go index 3ae4d4635..b83756820 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go @@ -16,6 +16,7 @@ import ( "github.com/open-feature/flagd/core/pkg/logger" "github.com/open-feature/flagd/core/pkg/model" "github.com/open-feature/flagd/core/pkg/service/ofrep" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -374,3 +375,38 @@ func TestFlagdContextDelegatesContextMerging(t *testing.T) { } } } + +func TestInvalidSelector_OFREPHandler(t *testing.T) { + const invalidSelector = "invalidKey=val" + log := logger.NewLogger(nil, false) + + t.Run("HandleFlagEvaluation", func(t *testing.T) { + eval := mock.NewMockIEvaluator(gomock.NewController(t)) + h := handler{Logger: log, evaluator: eval} + + req, _ := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags/"+flagKey, bytes.NewReader([]byte{})) + req.Header.Set("Flagd-Selector", invalidSelector) + + recorder := httptest.NewRecorder() + router := mux.NewRouter() + router.HandleFunc(singleEvaluation, h.HandleFlagEvaluation) + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + }) + + t.Run("HandleBulkEvaluation", func(t *testing.T) { + eval := mock.NewMockIEvaluator(gomock.NewController(t)) + h := handler{Logger: log, evaluator: eval} + + req, _ := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags", bytes.NewReader([]byte{})) + req.Header.Set("Flagd-Selector", invalidSelector) + + recorder := httptest.NewRecorder() + router := mux.NewRouter() + router.HandleFunc(bulkEvaluation, h.HandleBulkEvaluation) + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + }) +} diff --git a/flagd/pkg/service/flag-sync/handler_test.go b/flagd/pkg/service/flag-sync/handler_test.go index 7621be845..b869e409b 100644 --- a/flagd/pkg/service/flag-sync/handler_test.go +++ b/flagd/pkg/service/flag-sync/handler_test.go @@ -258,3 +258,35 @@ func TestSyncHandler_SelectorLocationPrecedence(t *testing.T) { }) } } + +func TestSyncHandler_InvalidSelector(t *testing.T) { + const invalidSelector = "invalidKey=val" + + flagStore, err := store.NewStore(logger.NewLogger(nil, false), []string{}) + require.NoError(t, err) + + h := syncHandler{ + store: flagStore, + log: logger.NewLogger(nil, false), + contextValues: map[string]any{}, + metricsRecorder: &telemetry.NoopMetricsRecorder{}, + } + + ctxWithInvalidSelector := metadata.NewIncomingContext( + context.Background(), + metadata.New(map[string]string{flagdService.FLAGD_SELECTOR_HEADER: invalidSelector}), + ) + + t.Run("SyncFlags", func(t *testing.T) { + stream := &mockSyncFlagsServer{ctx: ctxWithInvalidSelector, respReady: make(chan struct{}, 1)} + err := h.SyncFlags(&syncv1.SyncFlagsRequest{}, stream) + require.Error(t, err) + require.Equal(t, "rpc error: code = InvalidArgument desc = invalid selector key \"invalidKey\", valid keys: \"flagSetId\", \"source\"", err.Error()) + }) + + t.Run("FetchAllFlags", func(t *testing.T) { + _, err := h.FetchAllFlags(ctxWithInvalidSelector, &syncv1.FetchAllFlagsRequest{}) + require.Error(t, err) + require.Equal(t, "rpc error: code = InvalidArgument desc = invalid selector key \"invalidKey\", valid keys: \"flagSetId\", \"source\"", err.Error()) + }) +} From 9dac26efd319b541aebb7a01fe2b741526cd60db Mon Sep 17 00:00:00 2001 From: Alexandre Chakroun Date: Mon, 1 Jun 2026 23:49:13 +0200 Subject: [PATCH 6/6] Try and remove some duplication Signed-off-by: Alexandre Chakroun --- .../flag-evaluation/flag_evaluator_test.go | 32 +++++++++++-------- .../flag-evaluation/flag_evaluator_v1_test.go | 30 +++++------------ .../flag-evaluation/flag_evaluator_v2_test.go | 29 ++++------------- 3 files changed, 34 insertions(+), 57 deletions(-) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go index 6ad6bd3d6..d7e45797c 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_test.go @@ -1048,55 +1048,61 @@ func Test_Readable_ErrorMessage(t *testing.T) { } func TestInvalidSelector_OldFlagEvaluationService(t *testing.T) { - const invalidSelector = "invalidKey=val" ctrl := gomock.NewController(t) eval := mock.NewMockIEvaluator(ctrl) metrics, _ := getMetricReader() s := NewOldFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil) - tests := []struct { - name string - call func() error - }{ + assertInvalidSelectorConnect(t, []invalidSelectorCase{ {"ResolveAll", func() error { req := connect.NewRequest(&schemaV1.ResolveAllRequest{}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveAll(context.Background(), req) return err }}, {"ResolveBoolean", func() error { req := connect.NewRequest(&schemaV1.ResolveBooleanRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveBoolean(context.Background(), req) return err }}, {"ResolveString", func() error { req := connect.NewRequest(&schemaV1.ResolveStringRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveString(context.Background(), req) return err }}, {"ResolveInt", func() error { req := connect.NewRequest(&schemaV1.ResolveIntRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveInt(context.Background(), req) return err }}, {"ResolveFloat", func() error { req := connect.NewRequest(&schemaV1.ResolveFloatRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveFloat(context.Background(), req) return err }}, {"ResolveObject", func() error { req := connect.NewRequest(&schemaV1.ResolveObjectRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveObject(context.Background(), req) return err }}, - } + }) +} - for _, tt := range tests { +const invalidSelectorExpr = "invalidKey=val" + +type invalidSelectorCase struct { + name string + call func() error +} + +func assertInvalidSelectorConnect(t *testing.T, cases []invalidSelectorCase) { + t.Helper() + for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { err := tt.call() require.Error(t, err) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go index 24a15d181..2c6fd2fa6 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1_test.go @@ -1069,61 +1069,47 @@ func Test_mergeContexts(t *testing.T) { } func TestInvalidSelector_FlagEvaluationService(t *testing.T) { - const invalidSelector = "invalidKey=val" ctrl := gomock.NewController(t) eval := mock.NewMockIEvaluator(ctrl) metrics, _ := getMetricReader() s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) - tests := []struct { - name string - call func() error - }{ + assertInvalidSelectorConnect(t, []invalidSelectorCase{ {"ResolveAll", func() error { req := connect.NewRequest(&evalV1.ResolveAllRequest{}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveAll(context.Background(), req) return err }}, {"ResolveBoolean", func() error { req := connect.NewRequest(&evalV1.ResolveBooleanRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveBoolean(context.Background(), req) return err }}, {"ResolveString", func() error { req := connect.NewRequest(&evalV1.ResolveStringRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveString(context.Background(), req) return err }}, {"ResolveInt", func() error { req := connect.NewRequest(&evalV1.ResolveIntRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveInt(context.Background(), req) return err }}, {"ResolveFloat", func() error { req := connect.NewRequest(&evalV1.ResolveFloatRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveFloat(context.Background(), req) return err }}, {"ResolveObject", func() error { req := connect.NewRequest(&evalV1.ResolveObjectRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveObject(context.Background(), req) return err }}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.call() - require.Error(t, err) - var connectErr *connect.Error - require.True(t, errors.As(err, &connectErr)) - require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) - }) - } + }) } diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go index f88638b3e..4097cb7a0 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go @@ -2,7 +2,6 @@ package service import ( "context" - "errors" "testing" evalV2 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v2" @@ -124,55 +123,41 @@ func TestFlagEvaluationServiceV2_Fallback(t *testing.T) { } func TestInvalidSelector_FlagEvaluationServiceV2(t *testing.T) { - const invalidSelector = "invalidKey=val" ctrl := gomock.NewController(t) eval := mock.NewMockIEvaluator(ctrl) metrics, _ := getMetricReader() s := NewFlagEvaluationServiceV2(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil, 0) - tests := []struct { - name string - call func() error - }{ + assertInvalidSelectorConnect(t, []invalidSelectorCase{ {"ResolveBoolean", func() error { req := connect.NewRequest(&evalV2.ResolveBooleanRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveBoolean(context.Background(), req) return err }}, {"ResolveString", func() error { req := connect.NewRequest(&evalV2.ResolveStringRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveString(context.Background(), req) return err }}, {"ResolveInt", func() error { req := connect.NewRequest(&evalV2.ResolveIntRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveInt(context.Background(), req) return err }}, {"ResolveFloat", func() error { req := connect.NewRequest(&evalV2.ResolveFloatRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveFloat(context.Background(), req) return err }}, {"ResolveObject", func() error { req := connect.NewRequest(&evalV2.ResolveObjectRequest{FlagKey: "f"}) - req.Header().Set("Flagd-Selector", invalidSelector) + req.Header().Set("Flagd-Selector", invalidSelectorExpr) _, err := s.ResolveObject(context.Background(), req) return err }}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.call() - require.Error(t, err) - var connectErr *connect.Error - require.True(t, errors.As(err, &connectErr)) - require.Equal(t, connect.CodeInvalidArgument, connectErr.Code()) - }) - } + }) }