Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More proxy improvements #20

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 0 additions & 48 deletions cmd/tokenizer/log.go

This file was deleted.

7 changes: 6 additions & 1 deletion cmd/tokenizer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ func runServe() {

tkz := tokenizer.NewTokenizer(key)

server := &http.Server{Handler: loggingMiddleware(tkz)}
if len(os.Getenv("DEBUG")) != 0 {
tkz.ProxyHttpServer.Verbose = true
tkz.ProxyHttpServer.Logger = logrus.StandardLogger()
}

server := &http.Server{Handler: tkz}

go func() {
if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
Expand Down
25 changes: 8 additions & 17 deletions request_validator.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tokenizer

import (
"errors"
"fmt"
"net/http"
"regexp"
Expand All @@ -26,15 +25,11 @@ func AllowHosts(hosts ...string) RequestValidator {
}

func (v allowedHosts) Validate(r *http.Request) error {
host := r.URL.Host
if host == "" {
host = r.Host
if r.Host == "" {
return fmt.Errorf("%w: no host in request", ErrBadRequest)
}
if host == "" {
return errors.New("coun't find host in request")
}
if _, allowed := v[host]; !allowed {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host)
if _, allowed := v[r.Host]; !allowed {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host)
}
return nil
}
Expand All @@ -52,15 +47,11 @@ func AllowHostPattern(pattern *regexp.Regexp) RequestValidator {
}

func (v *allowedHostPattern) Validate(r *http.Request) error {
host := r.URL.Host
if host == "" {
host = r.Host
}
if host == "" {
return errors.New("coun't find host in request")
if r.Host == "" {
return fmt.Errorf("%w: no host in request", ErrBadRequest)
}
if match := (*regexp.Regexp)(v).MatchString(host); !match {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, host)
if match := (*regexp.Regexp)(v).MatchString(r.Host); !match {
return fmt.Errorf("%w: secret not valid for %s", ErrBadRequest, r.Host)
}
return nil
}
130 changes: 110 additions & 20 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"net"
"net/http"
"strings"
"time"
"unicode"

"github.com/elazarl/goproxy"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/nacl/box"
"golang.org/x/exp/maps"
)

var FilteredHeaders = []string{headerProxyAuthorization, headerProxyTokenizer}
Expand All @@ -41,10 +43,6 @@ type tokenizer struct {
pub *[32]byte
}

var _ goproxy.HttpsHandler = (*tokenizer)(nil)
var _ goproxy.ReqHandler = (*tokenizer)(nil)
var _ http.Handler = new(tokenizer)

func NewTokenizer(openKey string) *tokenizer {
privBytes, err := hex.DecodeString(openKey)
if err != nil {
Expand All @@ -71,8 +69,9 @@ func NewTokenizer(openKey string) *tokenizer {
}
proxy.ConnectDial = nil
proxy.ConnectDialWithReq = nil
proxy.OnRequest().HandleConnect(tkz)
proxy.OnRequest().Do(tkz)
proxy.OnRequest().HandleConnectFunc(tkz.HandleConnect)
proxy.OnRequest().DoFunc(tkz.HandleRequest)
proxy.OnResponse().DoFunc(tkz.HandleResponse)

return tkz
}
Expand All @@ -81,42 +80,84 @@ func (t *tokenizer) SealKey() string {
return hex.EncodeToString(t.pub[:])
}

// HandleConnect implements goproxy.HttpsHandler
// data that we can pass around between callbacks
type proxyUserData struct {
// processors from our handling of the initial CONNECT request if this is a
// tunneled connection.
connectProcessors []RequestProcessor

// start time of the CONNECT request if this is a tunneled connection.
connectStart time.Time
connLog logrus.FieldLogger

// start time of the current request. gets reset between requests within a
// tunneled connection.
requestStart time.Time
reqLog logrus.FieldLogger
}

// HandleConnect implements goproxy.FuncHttpsHandler
func (t *tokenizer) HandleConnect(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
pud := &proxyUserData{
connLog: logrus.WithField("connect_host", host),
connectStart: time.Now(),
}

_, port, _ := strings.Cut(host, ":")
if port == "443" {
logrus.WithField("host", host).Warn("attempt to proxy to https downstream")
pud.connLog.Warn("attempt to proxy to https downstream")
ctx.Resp = errorResponse(ErrBadRequest)
return goproxy.RejectConnect, ""
}

processors, err := t.processorsFromRequest(ctx.Req)
if err != nil {
var err error
if pud.connectProcessors, err = t.processorsFromRequest(ctx.Req); err != nil {
pud.connLog.WithError(err).Warn("find processor (CONNECT)")
ctx.Resp = errorResponse(err)
return goproxy.RejectConnect, ""
}

ctx.UserData = processors
ctx.UserData = pud

return goproxy.HTTPMitmConnect, host
}

// Handle implements goproxy.FuncReqHandler
func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
var processors []RequestProcessor
if ctx.UserData != nil {
processors = ctx.UserData.([]RequestProcessor)
// HandleRequest implements goproxy.FuncReqHandler
func (t *tokenizer) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
if ctx.UserData == nil {
ctx.UserData = &proxyUserData{}
}
pud, ok := ctx.UserData.(*proxyUserData)

if !ok || !pud.requestStart.IsZero() || pud.reqLog != nil {
logrus.Warn("bad proxyUserData")
return nil, errorResponse(ErrInternal)
}

pud.requestStart = time.Now()
if pud.connLog != nil {
pud.reqLog = pud.connLog
} else {
pud.reqLog = logrus.StandardLogger()
}
pud.reqLog = pud.reqLog.WithFields(logrus.Fields{
"method": req.Method,
"host": req.Host,
"path": req.URL.Path,
"queryKeys": strings.Join(maps.Keys(req.URL.Query()), ", "),
})

processors := append([]RequestProcessor(nil), pud.connectProcessors...)
if reqProcessors, err := t.processorsFromRequest(req); err != nil {
logrus.WithError(err).Warn("find processor")
pud.reqLog.WithError(err).Warn("find processor")
return req, errorResponse(err)
} else {
processors = append(processors, reqProcessors...)
}

for _, processor := range processors {
if err := processor(req); err != nil {
logrus.WithError(err).Warn("run processor")
pud.reqLog.WithError(err).Warn("run processor")
return nil, errorResponse(ErrBadRequest)
}
}
Expand All @@ -128,6 +169,47 @@ func (t *tokenizer) Handle(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Requ
return req, nil
}

// HandleResponse implements goproxy.FuncRespHandler
func (t *tokenizer) HandleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
// This callback is hit twice if there was an error in the downstream
// request. The first time a nil request is given and the second time we're
// given whatever we returned the first time. Skip logging on the second
// call. This should continue to work okay if
// https://github.com/elazarl/goproxy/pull/512 is ever merged.
if ctx.Error != nil && resp != nil {
return resp
}

pud, ok := ctx.UserData.(*proxyUserData)
if !ok || pud.requestStart.IsZero() || pud.reqLog == nil {
logrus.Warn("missing proxyUserData")
return errorResponse(ErrInternal)
}

log := pud.reqLog.WithField("durMS", int64(time.Since(pud.requestStart)/time.Millisecond))

if !pud.connectStart.IsZero() {
log = log.WithField("connDurMS", int64(time.Since(pud.connectStart)/time.Millisecond))
}
if resp != nil {
log = log.WithField("status", resp.StatusCode)
resp.Header.Set("Connection", "close")
}

// reset pud for next request in tunnel
pud.requestStart = time.Time{}
pud.reqLog = nil

if ctx.Error != nil {
log.WithError(ctx.Error).Warn()
return errorResponse(ctx.Error)
}

log.Info()

return resp
}

func (t *tokenizer) processorsFromRequest(req *http.Request) ([]RequestProcessor, error) {
hdrs := req.Header[headerProxyTokenizer]
processors := make([]RequestProcessor, 0, len(hdrs))
Expand Down Expand Up @@ -215,13 +297,20 @@ func errorResponse(err error) *http.Response {
status = http.StatusBadGateway
}

return &http.Response{StatusCode: status, Body: io.NopCloser(bytes.NewReader([]byte(err.Error())))}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(bytes.NewReader([]byte(err.Error()))),
Header: make(http.Header),
}
}

func forceTLSDialer(network, addr string) (net.Conn, error) {
if network != "tcp" {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("%w: dialing network %s not supported", ErrBadRequest, network)
}

hostname, port, _ := strings.Cut(addr, ":")
if hostname == "" {
return nil, fmt.Errorf("%w: attempt to dial without host: %q", ErrBadRequest, addr)
Expand All @@ -233,5 +322,6 @@ func forceTLSDialer(network, addr string) (net.Conn, error) {
port = "443"
}
addr = fmt.Sprintf("%s:%s", hostname, port)

return tls.Dial("tcp", addr, &tls.Config{RootCAs: upstreamTrust})
}
2 changes: 1 addition & 1 deletion tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestTokenizer(t *testing.T) {
// TLS error (proxy doesn't trust upstream)
resp, err := client.Get(appURL)
assert.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

// make proxy trust upstream
upstreamTrust.AddCert(appServer.Certificate())
Expand Down