Skip to content

Commit

Permalink
client helper for getting an http.RoundTripper (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
btoews authored Sep 26, 2023
1 parent c001cc8 commit ef2652d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.pem
/bin/*
/.envrc
.vscode
69 changes: 58 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/url"
"strings"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
Expand All @@ -22,34 +23,70 @@ func init() {
}
}

type ClientOption http.Header
type ClientOption func(*clientOptions)

type clientOptions struct {
headers http.Header
transport *http.Transport
}

func (co *clientOptions) getTransport() *http.Transport {
if co.transport != nil {
return co.transport.Clone()
}
return http.DefaultTransport.(*http.Transport).Clone()
}

func WithSecret(sealedSecret string, params map[string]string) ClientOption {
return ClientOption(http.Header{headerProxyTokenizer: {formatHeaderProxyTokenizer(sealedSecret, params)}})
return func(co *clientOptions) {
if co.headers == nil {
co.headers = make(http.Header)
}
co.headers.Add(headerProxyTokenizer, formatHeaderProxyTokenizer(sealedSecret, params))
}
}

func WithAuth(auth string) ClientOption {
return ClientOption(http.Header{headerProxyAuthorization: {fmt.Sprintf("Bearer %s", auth)}})
return func(co *clientOptions) {
if co.headers == nil {
co.headers = make(http.Header)
}
co.headers.Add(headerProxyAuthorization, fmt.Sprintf("Bearer %s", auth))
}
}

func WithTransport(t *http.Transport) ClientOption {
return func(co *clientOptions) {
co.transport = t
}
}

func Client(proxyURL string, opts ...ClientOption) (*http.Client, error) {
t, err := Transport(proxyURL, opts...)
if err != nil {
return nil, err
}

return &http.Client{Transport: t}, nil
}

func Transport(proxyURL string, opts ...ClientOption) (http.RoundTripper, error) {
u, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}

hdrs := make(http.Header, len(opts))
copts := &clientOptions{}
for _, o := range opts {
mergeHeader(hdrs, o)
o(copts)
}

t := headerInjector(&http.Transport{
Proxy: http.ProxyURL(u),
TLSClientConfig: &tls.Config{RootCAs: downstreamTrust},
ForceAttemptHTTP2: true,
}, hdrs)
t := copts.getTransport()
t.Proxy = http.ProxyURL(u)
t.TLSClientConfig = &tls.Config{RootCAs: downstreamTrust}
// t.ForceAttemptHTTP2 = true

return &http.Client{Transport: t}, nil
return forceHTTP(headerInjector(t, copts.headers)), nil
}

type roundTripperFunc func(*http.Request) (*http.Response, error)
Expand All @@ -58,6 +95,16 @@ func (rtf roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error)
return rtf(req)
}

func forceHTTP(t http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if strings.EqualFold(r.URL.Scheme, "https") {
r = r.Clone(r.Context())
r.URL.Scheme = "http"
}
return t.RoundTrip(r)
})
}

type ctxKey string

var ctxKeyInjected ctxKey = "injected"
Expand Down
6 changes: 4 additions & 2 deletions tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ func TestTokenizer(t *testing.T) {
assert.NoError(t, err)
creq, err := http.NewRequest(http.MethodConnect, appURL, nil)
assert.NoError(t, err)
mergeHeader(creq.Header, WithAuth(siAuth))
mergeHeader(creq.Header, WithSecret(si, nil))
opts := clientOptions{}
WithAuth(siAuth)(&opts)
WithSecret(si, nil)(&opts)
mergeHeader(creq.Header, opts.headers)
assert.NoError(t, creq.Write(conn))
resp, err = http.ReadResponse(connreader, creq)
assert.NoError(t, err)
Expand Down

0 comments on commit ef2652d

Please sign in to comment.