diff --git a/.gitignore b/.gitignore index 6671f09..209169d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.pem /bin/* /.envrc +.vscode diff --git a/client.go b/client.go index f51a991..4ff779f 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -22,34 +23,70 @@ func init() { } } -type ClientOption http.Header +type ClientOption func(*clientOptions) + +type clientOptions struct { + headers http.Header + transport *http.Transport +} + +func (co *clientOptions) getTransport() *http.Transport { + if co.transport != nil { + return co.transport.Clone() + } + return http.DefaultTransport.(*http.Transport).Clone() +} func WithSecret(sealedSecret string, params map[string]string) ClientOption { - return ClientOption(http.Header{headerProxyTokenizer: {formatHeaderProxyTokenizer(sealedSecret, params)}}) + return func(co *clientOptions) { + if co.headers == nil { + co.headers = make(http.Header) + } + co.headers.Add(headerProxyTokenizer, formatHeaderProxyTokenizer(sealedSecret, params)) + } } func WithAuth(auth string) ClientOption { - return ClientOption(http.Header{headerProxyAuthorization: {fmt.Sprintf("Bearer %s", auth)}}) + return func(co *clientOptions) { + if co.headers == nil { + co.headers = make(http.Header) + } + co.headers.Add(headerProxyAuthorization, fmt.Sprintf("Bearer %s", auth)) + } +} + +func WithTransport(t *http.Transport) ClientOption { + return func(co *clientOptions) { + co.transport = t + } } func Client(proxyURL string, opts ...ClientOption) (*http.Client, error) { + t, err := Transport(proxyURL, opts...) + if err != nil { + return nil, err + } + + return &http.Client{Transport: t}, nil +} + +func Transport(proxyURL string, opts ...ClientOption) (http.RoundTripper, error) { u, err := url.Parse(proxyURL) if err != nil { return nil, err } - hdrs := make(http.Header, len(opts)) + copts := &clientOptions{} for _, o := range opts { - mergeHeader(hdrs, o) + o(copts) } - t := headerInjector(&http.Transport{ - Proxy: http.ProxyURL(u), - TLSClientConfig: &tls.Config{RootCAs: downstreamTrust}, - ForceAttemptHTTP2: true, - }, hdrs) + t := copts.getTransport() + t.Proxy = http.ProxyURL(u) + t.TLSClientConfig = &tls.Config{RootCAs: downstreamTrust} + // t.ForceAttemptHTTP2 = true - return &http.Client{Transport: t}, nil + return forceHTTP(headerInjector(t, copts.headers)), nil } type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -58,6 +95,16 @@ func (rtf roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return rtf(req) } +func forceHTTP(t http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if strings.EqualFold(r.URL.Scheme, "https") { + r = r.Clone(r.Context()) + r.URL.Scheme = "http" + } + return t.RoundTrip(r) + }) +} + type ctxKey string var ctxKeyInjected ctxKey = "injected" diff --git a/tokenizer_test.go b/tokenizer_test.go index a9ba123..b2f7c1e 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -122,8 +122,10 @@ func TestTokenizer(t *testing.T) { assert.NoError(t, err) creq, err := http.NewRequest(http.MethodConnect, appURL, nil) assert.NoError(t, err) - mergeHeader(creq.Header, WithAuth(siAuth)) - mergeHeader(creq.Header, WithSecret(si, nil)) + opts := clientOptions{} + WithAuth(siAuth)(&opts) + WithSecret(si, nil)(&opts) + mergeHeader(creq.Header, opts.headers) assert.NoError(t, creq.Write(conn)) resp, err = http.ReadResponse(connreader, creq) assert.NoError(t, err)