Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward headers to upstream #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ func (c client) Do(ctx context.Context, req *http.Request) (*http.Response, []by
}
req.URL.RawQuery = reqParams.Encode()
}
if header, ok := getForwardedHeader(ctx); ok {
if req.Header == nil {
req.Header = make(http.Header)
}
for key, fh := range header {
req.Header[key] = fh
}
}
if c.authz != "" {
if req.Header == nil {
req.Header = make(http.Header)
Expand Down
20 changes: 20 additions & 0 deletions flag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package main

import (
"strings"
)

// StringSliceVar is a custom type that implements the flag.Value interface
// to store a list of strings.
type StringSliceVar []string

// String returns a string representation of the StringSliceVar type.
func (ss *StringSliceVar) String() string {
return strings.Join(*ss, ", ")
}

// Set appends a value to the StringSliceVar.
func (ss *StringSliceVar) Set(value string) error {
*ss = append(*ss, value)
return nil
}
34 changes: 34 additions & 0 deletions header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import (
"context"
"net/http"
)

type headerKey int

// addHeader adds forwarded headers to the context
func addForwardedHeader(ctx context.Context, h *http.Header, forwardHeaders *StringSliceVar) context.Context {
if forwardHeaders == nil {
return ctx
}
newH := make(http.Header)
for _, fh := range *forwardHeaders {
if values := (*h).Values(fh); values != nil {
for _, v := range values {
newH.Add(fh, v)
}
}
}
return context.WithValue(ctx, headerKey(0), newH)
}

// getForwardedHeader extracts from context the header
func getForwardedHeader(ctx context.Context) (http.Header, bool) {
if ctxValue := ctx.Value(headerKey(0)); ctxValue != nil {
if header, ok := ctxValue.(http.Header); ok {
return header, true
}
}
return nil, false
}
6 changes: 6 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var (
tlsSkipVerify bool
bearerFile string
forceGet bool
forwardHeaders StringSliceVar
)

func parseFlag() {
Expand All @@ -33,6 +34,7 @@ func parseFlag() {
flag.BoolVar(&tlsSkipVerify, "tlsSkipVerify", false, "Skip TLS Verification")
flag.StringVar(&bearerFile, "bearer-file", "", "File containing bearer token for API requests")
flag.BoolVar(&forceGet, "force-get", false, "Force api.Client to use GET by rejecting POST requests")
flag.Var(&forwardHeaders, "forward-header", "A header that will be forwarded to upstream")
flag.Parse()
}

Expand Down Expand Up @@ -79,6 +81,9 @@ func main() {
klog.Infof("Forcing api,Client to use GET requests")
options = append(options, withGet)
}
if forwardHeaders != nil {
klog.Infof("Following headers will be forwarded upstream: %v", forwardHeaders.String())
}
if c, err = newClient(c, options...); err != nil {
klog.Fatalf("error building custom API client:", err)
}
Expand All @@ -105,6 +110,7 @@ func federate(ctx context.Context, w http.ResponseWriter, r *http.Request, apiCl
if params.Del("match[]"); len(params) > 0 {
nctx = addValues(nctx, params)
}
nctx = addForwardedHeader(nctx, &r.Header, &forwardHeaders)
for _, matchQuery := range matchQueries {
start := time.Now()
// Ignoring warnings for now.
Expand Down