Skip to content

Commit

Permalink
Forward headers to upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
friedrichg committed Sep 22, 2023
1 parent 359183c commit cc489c6
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
8 changes: 8 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ func (c client) Do(ctx context.Context, req *http.Request) (*http.Response, []by
}
req.URL.RawQuery = reqParams.Encode()
}
if header, ok := getForwardedHeader(ctx); ok {
if req.Header == nil {
req.Header = make(http.Header)
}
for key, fh := range header {
req.Header[key] = fh
}
}
if c.authz != "" {
if req.Header == nil {
req.Header = make(http.Header)
Expand Down
20 changes: 20 additions & 0 deletions flag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package main

import (
"strings"
)

// StringSliceVar is a custom type that implements the flag.Value interface
// to store a list of strings.
type StringSliceVar []string

// String returns a string representation of the StringSliceVar type.
func (ss *StringSliceVar) String() string {
return strings.Join(*ss, ", ")
}

// Set appends a value to the StringSliceVar.
func (ss *StringSliceVar) Set(value string) error {
*ss = append(*ss, value)
return nil
}
34 changes: 34 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import (
"context"
"net/http"
)

type headerKey int

// addHeader adds forwarded headers to the context
func addForwardedHeader(ctx context.Context, h *http.Header, forwardHeaders *StringSliceVar) context.Context {
if forwardHeaders == nil {
return ctx
}
newH := make(http.Header)
for _, fh := range *forwardHeaders {
if values := (*h).Values(fh); values != nil {
for _, v := range values {
newH.Add(fh, v)
}
}
}
return context.WithValue(ctx, headerKey(0), newH)
}

// getForwardedHeader extracts from context the header
func getForwardedHeader(ctx context.Context) (http.Header, bool) {
if ctxValue := ctx.Value(headerKey(0)); ctxValue != nil {
if header, ok := ctxValue.(http.Header); ok {
return header, true
}
}
return nil, false
}
6 changes: 6 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var (
tlsSkipVerify bool
bearerFile string
forceGet bool
forwardHeaders StringSliceVar
)

func parseFlag() {
Expand All @@ -33,6 +34,7 @@ func parseFlag() {
flag.BoolVar(&tlsSkipVerify, "tlsSkipVerify", false, "Skip TLS Verification")
flag.StringVar(&bearerFile, "bearer-file", "", "File containing bearer token for API requests")
flag.BoolVar(&forceGet, "force-get", false, "Force api.Client to use GET by rejecting POST requests")
flag.Var(&forwardHeaders, "forward-header", "A header that will be forwarded to upstream")
flag.Parse()
}

Expand Down Expand Up @@ -79,6 +81,9 @@ func main() {
klog.Infof("Forcing api,Client to use GET requests")
options = append(options, withGet)
}
if forwardHeaders != nil {
klog.Infof("Following headers will be forwarded upstream: %v", forwardHeaders.String())
}
if c, err = newClient(c, options...); err != nil {
klog.Fatalf("error building custom API client:", err)
}
Expand Down Expand Up @@ -106,6 +111,7 @@ func federate(ctx context.Context, w http.ResponseWriter, r *http.Request, apiCl
if params.Del("match[]"); len(params) > 0 {
nctx = addValues(nctx, params)
}
nctx = addForwardedHeader(nctx, &r.Header, &forwardHeaders)
start := time.Now()
val, _, err := apiClient.Query(nctx, matchQuery, time.Now()) // Ignoring warnings for now.
responseTime := time.Since(start).Seconds()
Expand Down

0 comments on commit cc489c6

Please sign in to comment.