diff --git a/cmd/tokenizer/log.go b/cmd/tokenizer/log.go deleted file mode 100644 index 86c97f7..0000000 --- a/cmd/tokenizer/log.go +++ /dev/null @@ -1,48 +0,0 @@ -package main - -import ( - "net/http" - "strings" - "time" - - "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" -) - -type responseWriter interface { - http.ResponseWriter - http.Hijacker -} - -type statusCodeRecorder struct { - responseWriter - statusCode int -} - -func (w *statusCodeRecorder) WriteHeader(statusCode int) { - w.statusCode = statusCode - w.responseWriter.WriteHeader(statusCode) -} - -func loggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - scw := &statusCodeRecorder{responseWriter: w.(responseWriter)} - queryKeys := strings.Join(maps.Keys(r.URL.Query()), ", ") - start := time.Now() - - next.ServeHTTP(scw, r) - - if scw.statusCode == 0 { - scw.statusCode = http.StatusOK - } - - logrus.WithFields(logrus.Fields{ - "method": r.Method, - "host": r.Host, - "path": r.URL.Path, - "queryKeys": string(queryKeys), - "status": scw.statusCode, - "durms": int64(time.Since(start) / time.Millisecond), - }).Info() - }) -} diff --git a/cmd/tokenizer/main.go b/cmd/tokenizer/main.go index 79989f4..87f6fb8 100644 --- a/cmd/tokenizer/main.go +++ b/cmd/tokenizer/main.go @@ -80,7 +80,12 @@ func runServe() { tkz := tokenizer.NewTokenizer(key) - server := &http.Server{Handler: loggingMiddleware(tkz)} + if len(os.Getenv("DEBUG")) != 0 { + tkz.ProxyHttpServer.Verbose = true + tkz.ProxyHttpServer.Logger = logrus.StandardLogger() + } + + server := &http.Server{Handler: tkz} go func() { if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) { diff --git a/request_validator.go b/request_validator.go index 6c9515e..8ce8ff2 100644 --- a/request_validator.go +++ b/request_validator.go @@ -1,7 +1,6 @@ package tokenizer import ( - "errors" "fmt" "net/http" "regexp" @@ -26,15 +25,11 @@ func AllowHosts(hosts ...string) RequestValidator { } func (v allowedHosts) Validate(r *http.Request) error { - host := r.URL.Host - if host == "" { - host = r.Host + if r.Host == "" { + return fmt.Errorf("%w: no host in request", ErrBadRequest) } - if host == "" { - return errors.New("coun't find host in request") - } - if _, allowed := v[host]; !allowed { - return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host) + if _, allowed := v[r.Host]; !allowed { + return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host) } return nil } @@ -52,15 +47,11 @@ func AllowHostPattern(pattern *regexp.Regexp) RequestValidator { } func (v *allowedHostPattern) Validate(r *http.Request) error { - host := r.URL.Host - if host == "" { - host = r.Host - } - if host == "" { - return errors.New("coun't find host in request") + if r.Host == "" { + return fmt.Errorf("%w: no host in request", ErrBadRequest) } - if match := (*regexp.Regexp)(v).MatchString(host); !match { - return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host) + if match := (*regexp.Regexp)(v).MatchString(r.Host); !match { + return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host) } return nil } diff --git a/tokenizer.go b/tokenizer.go index f86b6d7..d574fb1 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -13,12 +13,14 @@ import ( "net" "net/http" "strings" + "time" "unicode" "github.com/elazarl/goproxy" "github.com/sirupsen/logrus" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" + "golang.org/x/exp/maps" ) var FilteredHeaders = []string{headerProxyAuthorization, headerProxyTokenizer} @@ -41,10 +43,6 @@ type tokenizer struct { pub *[32]byte } -var _ goproxy.HttpsHandler = (*tokenizer)(nil) -var _ goproxy.ReqHandler = (*tokenizer)(nil) -var _ http.Handler = new(tokenizer) - func NewTokenizer(openKey string) *tokenizer { privBytes, err := hex.DecodeString(openKey) if err != nil { @@ -71,8 +69,9 @@ func NewTokenizer(openKey string) *tokenizer { } proxy.ConnectDial = nil proxy.ConnectDialWithReq = nil - proxy.OnRequest().HandleConnect(tkz) - proxy.OnRequest().Do(tkz) + proxy.OnRequest().HandleConnectFunc(tkz.HandleConnect) + proxy.OnRequest().DoFunc(tkz.HandleRequest) + proxy.OnResponse().DoFunc(tkz.HandleResponse) return tkz } @@ -81,34 +80,76 @@ func (t *tokenizer) SealKey() string { return hex.EncodeToString(t.pub[:]) } -// HandleConnect implements goproxy.HttpsHandler +// data that we can pass around between callbacks +type proxyUserData struct { + // processors from our handling of the initial CONNECT request if this is a + // tunneled connection. + connectProcessors []RequestProcessor + + // start time of the CONNECT request if this is a tunneled connection. + connectStart time.Time + connLog logrus.FieldLogger + + // start time of the current request. gets reset between requests within a + // tunneled connection. + requestStart time.Time + reqLog logrus.FieldLogger +} + +// HandleConnect implements goproxy.FuncHttpsHandler func (t *tokenizer) HandleConnect(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { + pud := &proxyUserData{ + connLog: logrus.WithField("connect_host", host), + connectStart: time.Now(), + } + _, port, _ := strings.Cut(host, ":") if port == "443" { - logrus.WithField("host", host).Warn("attempt to proxy to https downstream") + pud.connLog.Warn("attempt to proxy to https downstream") ctx.Resp = errorResponse(ErrBadRequest) return goproxy.RejectConnect, "" } - processors, err := t.processorsFromRequest(ctx.Req) - if err != nil { + var err error + if pud.connectProcessors, err = t.processorsFromRequest(ctx.Req); err != nil { + pud.connLog.WithError(err).Warn("find processor (CONNECT)") ctx.Resp = errorResponse(err) return goproxy.RejectConnect, "" } - ctx.UserData = processors + ctx.UserData = pud + return goproxy.HTTPMitmConnect, host } -// Handle implements goproxy.FuncReqHandler -func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { - var processors []RequestProcessor - if ctx.UserData != nil { - processors = ctx.UserData.([]RequestProcessor) +// HandleRequest implements goproxy.FuncReqHandler +func (t *tokenizer) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + if ctx.UserData == nil { + ctx.UserData = &proxyUserData{} + } + pud, ok := ctx.UserData.(*proxyUserData) + + if !ok || !pud.requestStart.IsZero() || pud.reqLog != nil { + logrus.Warn("bad proxyUserData") + return nil, errorResponse(ErrInternal) + } + + pud.requestStart = time.Now() + if pud.connLog != nil { + pud.reqLog = pud.connLog + } else { + pud.reqLog = logrus.StandardLogger() } + pud.reqLog = pud.reqLog.WithFields(logrus.Fields{ + "method": req.Method, + "host": req.Host, + "path": req.URL.Path, + "queryKeys": strings.Join(maps.Keys(req.URL.Query()), ", "), + }) + processors := append([]RequestProcessor(nil), pud.connectProcessors...) if reqProcessors, err := t.processorsFromRequest(req); err != nil { - logrus.WithError(err).Warn("find processor") + pud.reqLog.WithError(err).Warn("find processor") return req, errorResponse(err) } else { processors = append(processors, reqProcessors...) @@ -116,7 +157,7 @@ func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Requ for _, processor := range processors { if err := processor(req); err != nil { - logrus.WithError(err).Warn("run processor") + pud.reqLog.WithError(err).Warn("run processor") return nil, errorResponse(ErrBadRequest) } } @@ -128,6 +169,47 @@ func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Requ return req, nil } +// HandleResponse implements goproxy.FuncRespHandler +func (t *tokenizer) HandleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { + // This callback is hit twice if there was an error in the downstream + // request. The first time a nil request is given and the second time we're + // given whatever we returned the first time. Skip logging on the second + // call. This should continue to work okay if + // https://github.com/elazarl/goproxy/pull/512 is ever merged. + if ctx.Error != nil && resp != nil { + return resp + } + + pud, ok := ctx.UserData.(*proxyUserData) + if !ok || pud.requestStart.IsZero() || pud.reqLog == nil { + logrus.Warn("missing proxyUserData") + return errorResponse(ErrInternal) + } + + log := pud.reqLog.WithField("durMS", int64(time.Since(pud.requestStart)/time.Millisecond)) + + if !pud.connectStart.IsZero() { + log = log.WithField("connDurMS", int64(time.Since(pud.connectStart)/time.Millisecond)) + } + if resp != nil { + log = log.WithField("status", resp.StatusCode) + resp.Header.Set("Connection", "close") + } + + // reset pud for next request in tunnel + pud.requestStart = time.Time{} + pud.reqLog = nil + + if ctx.Error != nil { + log.WithError(ctx.Error).Warn() + return errorResponse(ctx.Error) + } + + log.Info() + + return resp +} + func (t *tokenizer) processorsFromRequest(req *http.Request) ([]RequestProcessor, error) { hdrs := req.Header[headerProxyTokenizer] processors := make([]RequestProcessor, 0, len(hdrs)) @@ -215,13 +297,20 @@ func errorResponse(err error) *http.Response { status = http.StatusBadGateway } - return &http.Response{StatusCode: status, Body: io.NopCloser(bytes.NewReader([]byte(err.Error())))} + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(bytes.NewReader([]byte(err.Error()))), + Header: make(http.Header), + } } func forceTLSDialer(network, addr string) (net.Conn, error) { - if network != "tcp" { + switch network { + case "tcp", "tcp4", "tcp6": + default: return nil, fmt.Errorf("%w: dialing network %s not supported", ErrBadRequest, network) } + hostname, port, _ := strings.Cut(addr, ":") if hostname == "" { return nil, fmt.Errorf("%w: attempt to dial without host: %q", ErrBadRequest, addr) @@ -233,5 +322,6 @@ func forceTLSDialer(network, addr string) (net.Conn, error) { port = "443" } addr = fmt.Sprintf("%s:%s", hostname, port) + return tls.Dial("tcp", addr, &tls.Config{RootCAs: upstreamTrust}) } diff --git a/tokenizer_test.go b/tokenizer_test.go index d65a1ec..6de5a66 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -56,7 +56,7 @@ func TestTokenizer(t *testing.T) { // TLS error (proxy doesn't trust upstream) resp, err := client.Get(appURL) assert.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) // make proxy trust upstream upstreamTrust.AddCert(appServer.Certificate())