Skip to content

Commit

Permalink
Merge pull request #81 from 2manymws/refactor-io
Browse files Browse the repository at this point in the history
Avoid getting request body as much as possible in the handler.
  • Loading branch information
k1LoW authored Jun 24, 2024
2 parents 1439b7d + adb3e41 commit 3781590
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 64 deletions.
7 changes: 0 additions & 7 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
run:
go: 1.21
modules-download-mode: mod
linters:
fast: false
linters-settings:
staticcheck:
go: 1.16
issues:
exclude:
- SA3000
18 changes: 18 additions & 0 deletions copybuf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package rc

import "sync"

// Copy from net/http/server.go
const copyBufPoolSize = 32 * 1024

var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }}

func getCopyBuf() []byte { //nostyle:getters
return copyBufPool.Get().(*[copyBufPoolSize]byte)[:]
}
func putCopyBuf(b []byte) {
if len(b) != copyBufPoolSize {
panic("trying to put back buffer of the wrong size in the copyBufPool") //nostyle:dontpanic
}
copyBufPool.Put((*[copyBufPoolSize]byte)(b))
}
121 changes: 64 additions & 57 deletions rc.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,28 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
now := time.Now()

// Copy the request so that it is not affected by the next handler.
req, preq := m.duplicateRequest(req)
// reqc is the request to be used for caching.
req, reqc := m.duplicateRequest(req)

cachedReq, cachedRes, err := m.cacher.Load(preq) //nostyle:handlerrors
cachedReq, cachedRes, err := m.cacher.Load(reqc) //nostyle:handlerrors
if err != nil {
switch {
case errors.Is(err, ErrCacheNotFound):
m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("cache not found", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
case errors.Is(err, ErrCacheExpired):
m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("cache expired", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
case errors.Is(err, ErrShouldNotUseCache):
m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Debug("should not use cache", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
// Skip caching
next.ServeHTTP(w, req)
return
default:
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
}
}
cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, HandlerToRequester(next), now) //nostyle:handlerrors
cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, m.handlerToRequester(next, reqc, now), now) //nostyle:handlerrors
if err != nil {
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)))
}

// Response
Expand All @@ -131,50 +132,35 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
}
}
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body)
if err != nil {
m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
} else {
if _, err := w.Write(body); err != nil {
// Error as debug
// - os.ErrDeadlineExceeded: The request context has been canceled or has expired.
// - "client disconnected": The client disconnected. (net/http.http2errClientDisconnected)
// - "http2: stream closed": The client disconnected. (net/http.http2errStreamClosed)
// - syscall.ECONNRESET: The client disconnected. ("connection reset by peer")
// - syscall.EPIPE: The client disconnected. ("broken pipe")
// - http.ErrBodyNotAllowed: The request method does not allow a body.
switch {
case errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || contains([]string{"client disconnected", "http2: stream closed"}, err.Error()):
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
case errors.Is(err, http.ErrBodyNotAllowed):
// It is desirable that there should be no content body in the response, but the proxy server cannot handle it, so it is used as a debug log.
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
default:
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}

ww := w.(io.Writer)
buf := getCopyBuf()
defer putCopyBuf(buf)
if _, err := io.CopyBuffer(ww, res.Body, buf); err != nil {
// Error as debug
// - os.ErrDeadlineExceeded: The request context has been canceled or has expired.
// - "client disconnected": The client disconnected. (net/http.http2errClientDisconnected)
// - "http2: stream closed": The client disconnected. (net/http.http2errStreamClosed)
// - syscall.ECONNRESET: The client disconnected. ("connection reset by peer")
// - syscall.EPIPE: The client disconnected. ("broken pipe")
// - http.ErrBodyNotAllowed: The request method does not allow a body.
switch {
case errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || contains([]string{"client disconnected", "http2: stream closed"}, err.Error()):
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
case errors.Is(err, http.ErrBodyNotAllowed):
// It is desirable that there should be no content body in the response, but the proxy server cannot handle it, so it is used as a debug log.
m.logger.Debug("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
default:
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}
}
if err := res.Body.Close(); err != nil {
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}

if cacheUsed {
m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
return
}
ok, expires := m.cacher.Storable(preq, res, now)
if !ok {
m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
return
m.logger.Debug("cache used", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode))
}
// Restore response body
res.Body = io.NopCloser(bytes.NewReader(body))

// Store response as cache
if err := m.cacher.Store(preq, res, expires); err != nil {
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
}
m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
})
}

Expand All @@ -197,6 +183,32 @@ func (m *cacheMw) duplicateRequest(req *http.Request) (*http.Request, *http.Requ
return copy, req
}

func (m *cacheMw) handlerToRequester(h http.Handler, reqc *http.Request, now time.Time) func(*http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
rec := newRecorder()
defer rec.Reset()
h.ServeHTTP(rec, req)
res := rec.Result()
resc := rec.Result()

go func() {
ok, expires := m.cacher.Storable(reqc, resc, now)
if !ok {
m.logger.Debug("cache not storable", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(resc.Header)))
return
}

// Store response as cache
if err := m.cacher.Store(reqc, resc, expires); err != nil {
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", resc.StatusCode))
}
m.logger.Debug("cache stored", slog.String("host", reqc.Host), slog.String("method", reqc.Method), slog.String("url", reqc.URL.String()), slog.Any("headers", m.maskHeader(reqc.Header)), slog.Int("status", resc.StatusCode))
}()

return res, nil
}
}

func (m *cacheMw) maskHeader(h http.Header) http.Header {
const masked = "*****"
c := h.Clone()
Expand Down Expand Up @@ -237,17 +249,6 @@ func New(cacher Cacher, opts ...Option) func(next http.Handler) http.Handler {
return rl.Handler
}

// HandlerToRequester converts http.Handler to func(*http.Request) (*http.Response, error).
func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
rec := newRecorder()
h.ServeHTTP(rec, req)
res := rec.Result()
res.Header = rec.Header()
return res, nil
}
}

type recorder struct {
statusCode int
header http.Header
Expand Down Expand Up @@ -280,11 +281,17 @@ func (r *recorder) Result() *http.Response {
Status: http.StatusText(r.statusCode),
StatusCode: r.statusCode,
Header: r.header.Clone(),
Body: io.NopCloser(r.buf),
Body: io.NopCloser(bytes.NewReader(r.buf.Bytes())),
ContentLength: int64(r.buf.Len()),
}
}

func (r *recorder) Reset() {
r.statusCode = 0
r.header = make(http.Header)
r.buf.Reset()
}

func contains(s []string, e string) bool {
for _, v := range s {
if e == v {
Expand Down

0 comments on commit 3781590

Please sign in to comment.