From 40a5ab44ab5677dd710200d7bcec6fbe55e6ff20 Mon Sep 17 00:00:00 2001 From: Simon Pasquier Date: Fri, 7 Jun 2024 21:02:50 +0200 Subject: [PATCH] chore: refactor PromQL enforcer This change renames the Enforcer struct to PromQLEnforcer to clarify its purpose. It also removes the error type casting in favor of error wrapping. Signed-off-by: Simon Pasquier --- README.md | 2 +- injectproxy/enforce.go | 45 +++++++++++++----- injectproxy/enforce_test.go | 94 +++++++++++++++++++------------------ injectproxy/routes.go | 72 ++++++++++------------------ main.go | 5 ++ 5 files changed, 112 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index ce67bb9a..17412ff4 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ prom-label-proxy \ `prom-label-proxy` will enforce the `tenant=~"prometheus|alertmanager"` label selector in all requests. -You can match the label value using a regular expression with the `-regex-match` option. For example: +You can match the label value using a regular expression with the `-regex-match` option. For example: ``` prom-label-proxy \ diff --git a/injectproxy/enforce.go b/injectproxy/enforce.go index 8cc1f28a..0666cad2 100644 --- a/injectproxy/enforce.go +++ b/injectproxy/enforce.go @@ -14,40 +14,59 @@ package injectproxy import ( + "errors" "fmt" "github.com/prometheus/prometheus/model/labels" "github.com/prometheus/prometheus/promql/parser" ) -type Enforcer struct { +// PromQLEnforcer can enforce label matchers in PromQL expressions. +type PromQLEnforcer struct { labelMatchers map[string]*labels.Matcher errorOnReplace bool } -func NewEnforcer(errorOnReplace bool, ms ...*labels.Matcher) *Enforcer { +func NewPromQLEnforcer(errorOnReplace bool, ms ...*labels.Matcher) *PromQLEnforcer { entries := make(map[string]*labels.Matcher) for _, matcher := range ms { entries[matcher.Name] = matcher } - return &Enforcer{ + return &PromQLEnforcer{ labelMatchers: entries, errorOnReplace: errorOnReplace, } } -type IllegalLabelMatcherError struct { - msg string -} +var ( + // ErrQueryParse is returned when the input query is invalid. + ErrQueryParse = errors.New("failed to parse query string") + + // ErrIllegalLabelMatcher is returned when the input query contains a conflicting label matcher. + ErrIllegalLabelMatcher = errors.New("conflicting label matcher") + + // ErrEnforceLabel is returned when the label matchers couldn't be enforced. + ErrEnforceLabel = errors.New("failed to enforce label") +) -func (e IllegalLabelMatcherError) Error() string { return e.msg } +// Enforce the label matchers in a PromQL expression. +func (ms *PromQLEnforcer) Enforce(q string) (string, error) { + expr, err := parser.ParseExpr(q) + if err != nil { + return "", fmt.Errorf("%w: %w", ErrQueryParse, err) + } + + if err := ms.EnforceNode(expr); err != nil { + if errors.Is(err, ErrIllegalLabelMatcher) { + return "", err + } -func newIllegalLabelMatcherError(existing string, replacement string) IllegalLabelMatcherError { - return IllegalLabelMatcherError{ - msg: fmt.Sprintf("label matcher value (%s) conflicts with injected value (%s)", existing, replacement), + return "", fmt.Errorf("%w: %w", ErrEnforceLabel, err) } + + return expr.String(), nil } // EnforceNode walks the given node recursively @@ -57,7 +76,7 @@ func newIllegalLabelMatcherError(existing string, replacement string) IllegalLab // their label enforcer is being potentially modified. // If a node's label matcher has the same name as a label matcher // of the given enforcer, then it will be replaced. -func (ms Enforcer) EnforceNode(node parser.Node) error { +func (ms PromQLEnforcer) EnforceNode(node parser.Node) error { switch n := node.(type) { case *parser.EvalStmt: if err := ms.EnforceNode(n.Expr); err != nil { @@ -140,14 +159,14 @@ func (ms Enforcer) EnforceNode(node parser.Node) error { // * if errorOnReplace is true, an error is returned, // * if errorOnReplace is false and the label matcher type is '=', the existing matcher is silently replaced. // * otherwise the existing matcher is preserved. -func (ms Enforcer) EnforceMatchers(targets []*labels.Matcher) ([]*labels.Matcher, error) { +func (ms PromQLEnforcer) EnforceMatchers(targets []*labels.Matcher) ([]*labels.Matcher, error) { var res []*labels.Matcher for _, target := range targets { if matcher, ok := ms.labelMatchers[target.Name]; ok { // matcher.String() returns something like "labelfoo=value" if ms.errorOnReplace && matcher.String() != target.String() { - return res, newIllegalLabelMatcherError(matcher.String(), target.String()) + return res, fmt.Errorf("%w: label matcher value %q conflicts with injected value %q", ErrIllegalLabelMatcher, matcher.String(), target.String()) } // Drop the existing matcher only if the enforced matcher is an diff --git a/injectproxy/enforce_test.go b/injectproxy/enforce_test.go index 59b5e915..1ca06cff 100644 --- a/injectproxy/enforce_test.go +++ b/injectproxy/enforce_test.go @@ -14,11 +14,11 @@ package injectproxy import ( + "errors" "fmt" "testing" "github.com/prometheus/prometheus/model/labels" - "github.com/prometheus/prometheus/promql/parser" ) type checkFunc func(expression string, err error) error @@ -34,32 +34,22 @@ func checks(cs ...checkFunc) checkFunc { } } -func hasError(want error) checkFunc { +func noError() checkFunc { return func(_ string, got error) error { - wantError, gotError := "", "" - - if want != nil { - wantError = fmt.Sprintf("%q", want.Error()) - } - if got != nil { - gotError = fmt.Sprintf("%q", got.Error()) - } - - if wantError != gotError { - return fmt.Errorf("want error %v, got %v", wantError, gotError) + return fmt.Errorf("want error , got %v", got) } return nil } } -func hasIllegalLabelMatcherError() checkFunc { +func errorIs(want error) checkFunc { return func(_ string, got error) error { - if _, ok := got.(IllegalLabelMatcherError); ok { + if errors.Is(got, want) { return nil } - return fmt.Errorf("want error of type IllegalLabelMatcherError, got %v", got) + return fmt.Errorf("want error of type %T, got %v", want, got) } } @@ -75,14 +65,14 @@ func hasExpression(want string) checkFunc { var tests = []struct { name string expression string - enforcer *Enforcer + enforcer *PromQLEnforcer check checkFunc }{ // first check correct label insertion { name: "expressions add label", expression: `round(metric1{label="baz"},3)`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -96,7 +86,7 @@ var tests = []struct { }, ), check: checks( - hasError(nil), + noError(), hasExpression(`round(metric1{label="baz",namespace="NS",pod="POD"}, 3)`), ), }, @@ -104,7 +94,7 @@ var tests = []struct { { name: "aggregate add label", expression: `sum by (pod) (metric1{label="baz"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -118,7 +108,7 @@ var tests = []struct { }, ), check: checks( - hasError(nil), + noError(), hasExpression(`sum by (pod) (metric1{label="baz",namespace="NS",pod="POD"})`), ), }, @@ -126,7 +116,7 @@ var tests = []struct { { name: "binary expression add label", expression: `metric1{} + sum by (pod) (metric2{label="baz"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -140,7 +130,7 @@ var tests = []struct { }, ), check: checks( - hasError(nil), + noError(), hasExpression(`metric1{namespace="NS",pod="POD"} + sum by (pod) (metric2{label="baz",namespace="NS",pod="POD"})`), ), }, @@ -148,7 +138,7 @@ var tests = []struct { { name: "binary expression with vector matching add label", expression: `metric1{} + on(pod,namespace) sum by (pod) (metric2{label="baz"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -162,7 +152,7 @@ var tests = []struct { }, ), check: checks( - hasError(nil), + noError(), hasExpression(`metric1{namespace="NS",pod="POD"} + on (pod, namespace) sum by (pod) (metric2{label="baz",namespace="NS",pod="POD"})`), ), }, @@ -171,7 +161,7 @@ var tests = []struct { { name: "expressions error on non-matching label value", expression: `round(metric1{label="baz",pod="POD",namespace="bar"},3)`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( true, &labels.Matcher{ Name: "namespace", @@ -185,14 +175,14 @@ var tests = []struct { }, ), check: checks( - hasIllegalLabelMatcherError(), + errorIs(ErrIllegalLabelMatcher), ), }, { name: "aggregate error on non-matching label value", expression: `sum by (pod) (metric1{label="baz",pod="foo",namespace="bar"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( true, &labels.Matcher{ Name: "namespace", @@ -206,14 +196,14 @@ var tests = []struct { }, ), check: checks( - hasIllegalLabelMatcherError(), + errorIs(ErrIllegalLabelMatcher), ), }, { name: "binary expression error on non-matching label value", expression: `metric1{pod="baz"} + sum by (pod) (metric2{label="baz",pod="foo",namespace="bar"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( true, &labels.Matcher{ Name: "namespace", @@ -227,14 +217,14 @@ var tests = []struct { }, ), check: checks( - hasIllegalLabelMatcherError(), + errorIs(ErrIllegalLabelMatcher), ), }, { name: "binary expression with vector matching error on non-matching label value", expression: `metric1{pod="baz"} + on (pod,namespace) sum by (pod) (metric2{label="baz",pod="foo",namespace="bar"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( true, &labels.Matcher{ Name: "namespace", @@ -248,7 +238,7 @@ var tests = []struct { }, ), check: checks( - hasIllegalLabelMatcherError(), + errorIs(ErrIllegalLabelMatcher), ), }, // and lastly check that passing the label matcher we would inject @@ -256,7 +246,7 @@ var tests = []struct { { name: "expressions unchanged with matching label value", expression: `round(metric1{label="baz",pod="POD",namespace="NS"},3)`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -270,6 +260,7 @@ var tests = []struct { }, ), check: checks( + noError(), hasExpression(`round(metric1{label="baz",namespace="NS",pod="POD"}, 3)`), ), }, @@ -277,7 +268,7 @@ var tests = []struct { { name: "aggregate unchanged with matching label value", expression: `sum by (pod) (metric1{label="baz",pod="POD",namespace="NS"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -291,6 +282,7 @@ var tests = []struct { }, ), check: checks( + noError(), hasExpression(`sum by (pod) (metric1{label="baz",namespace="NS",pod="POD"})`), ), }, @@ -298,7 +290,7 @@ var tests = []struct { { name: "binary expression unchanged with matching label value", expression: `metric1{pod="POD"} + sum by (pod) (metric2{label="baz",namespace="NS",pod="POD"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -312,6 +304,7 @@ var tests = []struct { }, ), check: checks( + noError(), hasExpression(`metric1{namespace="NS",pod="POD"} + sum by (pod) (metric2{label="baz",namespace="NS",pod="POD"})`), ), }, @@ -319,7 +312,7 @@ var tests = []struct { { name: "binary expression with vector matching unchanged with matching label value", expression: `metric1{pod="POD"} + on (pod,namespace) sum by (pod) (metric2{label="baz",pod="POD",namespace="NS"})`, - enforcer: NewEnforcer( + enforcer: NewPromQLEnforcer( false, &labels.Matcher{ Name: "namespace", @@ -333,23 +326,34 @@ var tests = []struct { }, ), check: checks( + noError(), hasExpression(`metric1{namespace="NS",pod="POD"} + on (pod, namespace) sum by (pod) (metric2{label="baz",namespace="NS",pod="POD"})`), ), }, + { + name: "invalid PromQL expression", + expression: `metric1{pod="baz"`, + enforcer: NewPromQLEnforcer( + false, + &labels.Matcher{ + Name: "namespace", + Type: labels.MatchEqual, + Value: "NS", + }, + ), + check: checks( + errorIs(ErrQueryParse), + ), + }, } -func TestEnforceNode(t *testing.T) { +func TestEnforce(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - e, err := parser.ParseExpr(tc.expression) - if err != nil { + got, err := tc.enforcer.Enforce(tc.expression) + if err := tc.check(got, err); err != nil { t.Fatal(err) } - - err = tc.enforcer.EnforceNode(e) - if err := tc.check(e.String(), err); err != nil { - t.Error(err) - } }) } } diff --git a/injectproxy/routes.go b/injectproxy/routes.go index 30e18a60..f851e073 100644 --- a/injectproxy/routes.go +++ b/injectproxy/routes.go @@ -16,6 +16,7 @@ package injectproxy import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -473,6 +474,7 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { prometheusAPIError(w, "Only one label value allowed with regex match", http.StatusBadRequest) return } + matcher = &labels.Matcher{ Name: r.label, Type: labels.MatchRegexp, @@ -493,6 +495,7 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { } matcherType = labels.MatchRegexp } + matcher = &labels.Matcher{ Name: r.label, Type: matcherType, @@ -500,7 +503,7 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { } } - e := NewEnforcer(r.errorOnReplace, matcher) + e := NewPromQLEnforcer(r.errorOnReplace, matcher) // The `query` can come in the URL query string and/or the POST body. // For this reason, we need to try to enforcing in both places. @@ -509,14 +512,15 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { // enforce in both places. q, found1, err := enforceQueryValues(e, req.URL.Query()) if err != nil { - switch err.(type) { - case IllegalLabelMatcherError: + switch { + case errors.Is(err, ErrIllegalLabelMatcher): prometheusAPIError(w, err.Error(), http.StatusBadRequest) - case queryParseError: + case errors.Is(err, ErrQueryParse): prometheusAPIError(w, err.Error(), http.StatusBadRequest) - case enforceLabelError: + case errors.Is(err, ErrEnforceLabel): prometheusAPIError(w, err.Error(), http.StatusInternalServerError) } + return } req.URL.RawQuery = q @@ -529,16 +533,18 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { } q, found2, err = enforceQueryValues(e, req.PostForm) if err != nil { - switch err.(type) { - case IllegalLabelMatcherError: + switch { + case errors.Is(err, ErrIllegalLabelMatcher): prometheusAPIError(w, err.Error(), http.StatusBadRequest) - case queryParseError: + case errors.Is(err, ErrQueryParse): prometheusAPIError(w, err.Error(), http.StatusBadRequest) - case enforceLabelError: + case errors.Is(err, ErrEnforceLabel): prometheusAPIError(w, err.Error(), http.StatusInternalServerError) } + return } + // We are replacing request body, close previous one (ParseForm ensures it is read fully and not nil). _ = req.Body.Close() req.Body = io.NopCloser(strings.NewReader(q)) @@ -553,33 +559,29 @@ func (r *routes) query(w http.ResponseWriter, req *http.Request) { r.handler.ServeHTTP(w, req) } -func enforceQueryValues(e *Enforcer, v url.Values) (values string, noQuery bool, err error) { +func enforceQueryValues(e *PromQLEnforcer, v url.Values) (values string, noQuery bool, err error) { // If no values were given or no query is present, // e.g. because the query came in the POST body // but the URL query string was passed, then finish early. if v.Get(queryParam) == "" { return v.Encode(), false, nil } - expr, err := parser.ParseExpr(v.Get(queryParam)) + + q, err := e.Enforce(v.Get(queryParam)) if err != nil { - queryParseError := newQueryParseError(err) - return "", true, queryParseError + return "", true, err } - if err := e.EnforceNode(expr); err != nil { - if _, ok := err.(IllegalLabelMatcherError); ok { - return "", true, err - } - enforceLabelError := newEnforceLabelError(err) - return "", true, enforceLabelError - } + v.Set(queryParam, q) - v.Set(queryParam, expr.String()) return v.Encode(), true, nil } -// matcher ensures all the provided match[] if any has label injected. If none was provided, single matcher is injected. -// This works for non-query Prometheus APIs like: /api/v1/series, /api/v1/label//values, /api/v1/labels and /federate support multiple matchers. +// matcher modifies all the match[] HTTP parameters to match on the tenant label. +// If none was provided, a tenant label matcher matcher is injected. +// This works for non-query Prometheus API endpoints like /api/v1/series, +// /api/v1/label//values, /api/v1/labels and /federate which support +// multiple matchers. // See e.g https://prometheus.io/docs/prometheus/latest/querying/api/#querying-metadata func (r *routes) matcher(w http.ResponseWriter, req *http.Request) { matcher := &labels.Matcher{ @@ -643,30 +645,6 @@ func matchersToString(ms ...*labels.Matcher) string { return fmt.Sprintf("{%v}", strings.Join(el, ",")) } -type queryParseError struct { - msg string -} - -func (e queryParseError) Error() string { - return e.msg -} - -func newQueryParseError(err error) queryParseError { - return queryParseError{msg: fmt.Sprintf("error parsing query string %q", err.Error())} -} - -type enforceLabelError struct { - msg string -} - -func (e enforceLabelError) Error() string { - return e.msg -} - -func newEnforceLabelError(err error) enforceLabelError { - return enforceLabelError{msg: fmt.Sprintf("error enforcing label %q", err.Error())} -} - // humanFriendlyErrorMessage returns an error message with a capitalized first letter // and a punctuation at the end. func humanFriendlyErrorMessage(err error) string { diff --git a/main.go b/main.go index 4d3b05eb..c1412b51 100644 --- a/main.go +++ b/main.go @@ -124,9 +124,11 @@ func main() { if enableLabelAPIs { opts = append(opts, injectproxy.WithEnabledLabelsAPI()) } + if len(unsafePassthroughPaths) > 0 { opts = append(opts, injectproxy.WithPassthroughPaths(strings.Split(unsafePassthroughPaths, ","))) } + if errorOnReplace { opts = append(opts, injectproxy.WithErrorOnReplace()) } @@ -136,16 +138,19 @@ func main() { if len(labelValues) > 1 { log.Fatalf("Regex match is limited to one label value") } + compiledRegex, err := regexp.Compile(labelValues[0]) if err != nil { log.Fatalf("Invalid regexp: %v", err.Error()) return } + if compiledRegex.MatchString("") { log.Fatalf("Regex should not match empty string") return } } + opts = append(opts, injectproxy.WithRegexMatch()) }