diff --git a/go.mod b/go.mod index 0da8eb07..d9d00daa 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( dario.cat/mergo v1.0.1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.3.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -53,6 +54,7 @@ require ( github.com/goccy/go-yaml v1.18.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/huandu/xstrings v1.5.0 // indirect + github.com/klauspost/compress v1.18.4 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/prometheus/client_model v0.5.0 // indirect diff --git a/go.sum b/go.sum index 6e6b2e08..fcd0b569 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -112,6 +114,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/core/execute_tool_test.go b/internal/core/execute_tool_test.go index 67166512..928239df 100644 --- a/internal/core/execute_tool_test.go +++ b/internal/core/execute_tool_test.go @@ -1,15 +1,22 @@ package core import ( + "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" "context" + "io" "net/http" "net/http/httptest" "testing" + "github.com/andybalholm/brotli" "github.com/amoylab/unla/internal/common/config" "github.com/amoylab/unla/internal/mcp/session" "github.com/amoylab/unla/pkg/mcp" "github.com/gin-gonic/gin" + "github.com/klauspost/compress/zstd" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -74,3 +81,166 @@ func TestExecuteHTTPTool_ForwardHeadersAndRequestError(t *testing.T) { assert.Error(t, err) assert.Nil(t, res) } + +func TestExecuteHTTPTool_GzipResponse(t *testing.T) { + // downstream returns gzip-compressed JSON + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(`{"hello":"gzip"}`)) + _ = gz.Close() + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "gzip") + _, _ = w.Write(buf.Bytes()) + })) + defer srv.Close() + + allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"}) + s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist} + tool := &config.ToolConfig{ + Name: "t", + Method: http.MethodGet, + Endpoint: srv.URL, + Headers: map[string]string{"Accept-Encoding": "gzip"}, + ResponseBody: "{{.Response.Body}}", + } + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{"X-Req": "v"}}}} + c, _ := gin.CreateTestContext(nil) + c.Request = req + + res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{}) + assert.NoError(t, err) + if assert.NotNil(t, res) { + if tc, ok := res.Content[0].(*mcp.TextContent); ok { + assert.Equal(t, `{"hello":"gzip"}`, tc.Text) + } else { + t.Fatalf("unexpected content type") + } + } +} + +func TestReadDecodedResponseBody_GzipInvalidData(t *testing.T) { + resp := &http.Response{ + Header: http.Header{"Content-Encoding": []string{"gzip"}}, + Body: io.NopCloser(bytes.NewBufferString("not-gzip-data")), + } + + _, err := readDecodedResponseBody(resp) + assert.Error(t, err) +} + +func TestExecuteHTTPTool_BrotliResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + br := brotli.NewWriter(&buf) + _, _ = br.Write([]byte(`{"hello":"brotli"}`)) + _ = br.Close() + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "br") + _, _ = w.Write(buf.Bytes()) + })) + defer srv.Close() + + allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"}) + s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist} + tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"} + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}} + c, _ := gin.CreateTestContext(nil) + c.Request = req + + res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{}) + assert.NoError(t, err) + if assert.NotNil(t, res) { + tc, ok := res.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("unexpected content type") + } + assert.Equal(t, `{"hello":"brotli"}`, tc.Text) + } +} + +func TestExecuteHTTPTool_ZstdResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + zw, err := zstd.NewWriter(&buf) + assert.NoError(t, err) + _, _ = zw.Write([]byte(`{"hello":"zstd"}`)) + zw.Close() + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "zstd") + _, _ = w.Write(buf.Bytes()) + })) + defer srv.Close() + + allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"}) + s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist} + tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"} + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}} + c, _ := gin.CreateTestContext(nil) + c.Request = req + + res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{}) + assert.NoError(t, err) + if assert.NotNil(t, res) { + tc, ok := res.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("unexpected content type") + } + assert.Equal(t, `{"hello":"zstd"}`, tc.Text) + } +} + +func TestExecuteHTTPTool_DeflateResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var buf bytes.Buffer + zw := zlib.NewWriter(&buf) + _, _ = zw.Write([]byte(`{"hello":"deflate"}`)) + _ = zw.Close() + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "deflate") + _, _ = w.Write(buf.Bytes()) + })) + defer srv.Close() + + allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"}) + s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist} + tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"} + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}} + c, _ := gin.CreateTestContext(nil) + c.Request = req + + res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{}) + assert.NoError(t, err) + if assert.NotNil(t, res) { + tc, ok := res.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("unexpected content type") + } + assert.Equal(t, `{"hello":"deflate"}`, tc.Text) + } +} + +func TestReadDecodedResponseBody_DeflateRaw(t *testing.T) { + var buf bytes.Buffer + fw, err := flate.NewWriter(&buf, flate.DefaultCompression) + assert.NoError(t, err) + _, _ = fw.Write([]byte(`{"hello":"raw-deflate"}`)) + _ = fw.Close() + + resp := &http.Response{ + Header: http.Header{"Content-Encoding": []string{"deflate"}}, + Body: io.NopCloser(bytes.NewReader(buf.Bytes())), + } + + body, err := readDecodedResponseBody(resp) + assert.NoError(t, err) + assert.Equal(t, `{"hello":"raw-deflate"}`, string(body)) +} diff --git a/internal/core/tool.go b/internal/core/tool.go index 856cf1e4..018a422a 100644 --- a/internal/core/tool.go +++ b/internal/core/tool.go @@ -2,6 +2,9 @@ package core import ( "bytes" + "compress/flate" + "compress/gzip" + "compress/zlib" "context" "encoding/json" "fmt" @@ -13,6 +16,8 @@ import ( "slices" "strings" + "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" "github.com/gin-gonic/gin" "go.uber.org/zap" "golang.org/x/net/proxy" @@ -30,6 +35,101 @@ import ( "github.com/amoylab/unla/pkg/mcp" ) +func parseContentEncodings(contentEncoding string) []string { + parts := strings.Split(contentEncoding, ",") + encodings := make([]string, 0, len(parts)) + for _, part := range parts { + encoding := strings.ToLower(strings.TrimSpace(part)) + if encoding == "" || encoding == "identity" { + continue + } + encodings = append(encodings, encoding) + } + return encodings +} + +func decodeBodyBytesByEncoding(body []byte, encoding string) ([]byte, error) { + var ( + reader io.Reader = bytes.NewReader(body) + closeFn func() error + decodeError error + ) + + switch encoding { + case "gzip", "x-gzip": + gzReader, err := gzip.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("failed to create gzip reader: %w", err) + } + reader = gzReader + closeFn = gzReader.Close + case "br": + reader = brotli.NewReader(reader) + case "zstd": + zstdReader, err := zstd.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("failed to create zstd reader: %w", err) + } + reader = zstdReader + closeFn = func() error { + zstdReader.Close() + return nil + } + case "deflate": + zlibReader, err := zlib.NewReader(reader) + if err == nil { + reader = zlibReader + closeFn = zlibReader.Close + break + } + // Some servers send raw deflate stream for deflate token. + flateReader := flate.NewReader(bytes.NewReader(body)) + reader = flateReader + closeFn = flateReader.Close + default: + return nil, fmt.Errorf("unsupported content encoding: %s", encoding) + } + + decoded, err := io.ReadAll(reader) + if err != nil { + decodeError = fmt.Errorf("failed to decode %s body: %w", encoding, err) + } + if closeFn != nil { + closeErr := closeFn() + if decodeError != nil { + return nil, decodeError + } + if closeErr != nil { + return nil, fmt.Errorf("failed to close %s decoder: %w", encoding, closeErr) + } + } + if decodeError != nil { + return nil, decodeError + } + + return decoded, nil +} + +func readDecodedResponseBody(resp *http.Response) ([]byte, error) { + if resp == nil || resp.Body == nil { + return nil, nil + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + encodings := parseContentEncodings(resp.Header.Get("Content-Encoding")) + for i := len(encodings) - 1; i >= 0; i-- { + body, err = decodeBodyBytesByEncoding(body, encodings[i]) + if err != nil { + return nil, err + } + } + return body, nil +} + // shouldIgnoreHeader checks if a header should be ignored based on configuration func (s *Server) shouldIgnoreHeader(headerName string) bool { // If forward is disabled, don't ignore any headers (backward compatibility) @@ -330,8 +430,8 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool * } defer resp.Body.Close() - // Read response body for logging in case of error - respBodyBytes, err := io.ReadAll(resp.Body) + // Read and decode response body for logging and unified downstream processing. + respBodyBytes, err := readDecodedResponseBody(resp) if err != nil { logger.Error("failed to read response body", zap.String("tool", tool.Name), @@ -341,8 +441,10 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool * return nil, fmt.Errorf("failed to read response body: %w", err) } - // Restore response body for further processing + // Restore decoded response body for further processing. resp.Body = io.NopCloser(bytes.NewBuffer(respBodyBytes)) + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") // Log response status logger.Debug("received HTTP response",