Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ require (
dario.cat/mergo v1.0.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
Expand All @@ -53,6 +54,7 @@ require (
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
Expand Down Expand Up @@ -112,6 +114,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
Expand Down
170 changes: 170 additions & 0 deletions internal/core/execute_tool_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package core

import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/andybalholm/brotli"
"github.com/amoylab/unla/internal/common/config"
"github.com/amoylab/unla/internal/mcp/session"
"github.com/amoylab/unla/pkg/mcp"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -74,3 +81,166 @@ func TestExecuteHTTPTool_ForwardHeadersAndRequestError(t *testing.T) {
assert.Error(t, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add coverage for multiple Content-Encoding values and encoding order in readDecodedResponseBody

Since readDecodedResponseBody handles multiple encodings by iterating them in reverse, please add at least one test with a chained Content-Encoding (e.g. gzip, br and/or br, gzip) to verify that: (1) decoding happens in the correct order, and (2) errors from any step in the chain are surfaced. This will validate the reverse-iteration logic and protect against regressions for multi-encoding responses.

Suggested implementation:

func TestReadDecodedResponseBody_GzipInvalidData(t *testing.T) {
	resp := &http.Response{
		Header: http.Header{"Content-Encoding": []string{"gzip"}},
		Body:   io.NopCloser(bytes.NewBufferString("not-gzip-data")),
	}

	_, err := readDecodedResponseBody(resp)
	assert.Error(t, err)
}

func encodeWithGzipThenZlib(t *testing.T, data []byte) []byte {
	var gzipBuf bytes.Buffer
	gzipWriter := gzip.NewWriter(&gzipBuf)
	_, err := gzipWriter.Write(data)
	if err != nil {
		t.Fatalf("failed to gzip-compress data: %v", err)
	}
	if err := gzipWriter.Close(); err != nil {
		t.Fatalf("failed to close gzip writer: %v", err)
	}

	var zlibBuf bytes.Buffer
	zlibWriter := zlib.NewWriter(&zlibBuf)
	_, err = zlibWriter.Write(gzipBuf.Bytes())
	if err != nil {
		t.Fatalf("failed to zlib-compress data: %v", err)
	}
	if err := zlibWriter.Close(); err != nil {
		t.Fatalf("failed to close zlib writer: %v", err)
	}

	return zlibBuf.Bytes()
}

func TestReadDecodedResponseBody_MultiEncoding_GzipThenDeflate_Success(t *testing.T) {
	original := []byte(`{"hello":"multi"}`)
	encoded := encodeWithGzipThenZlib(t, original)

	resp := &http.Response{
		Header: http.Header{
			// Encodings are listed in the order they were applied.
			// Body is zlib(deflate(gzip(original))).
			"Content-Encoding": []string{"gzip", "deflate"},
		},
		Body: io.NopCloser(bytes.NewReader(encoded)),
	}

	body, err := readDecodedResponseBody(resp)
	if assert.NoError(t, err) {
		assert.Equal(t, original, body)
	}
}

func TestReadDecodedResponseBody_MultiEncoding_ErrorInInnerEncoding(t *testing.T) {
	// Construct a body that can be successfully zlib-decoded, but is not valid gzip,
	// so that the second decoding step fails.
	var zlibBuf bytes.Buffer
	zlibWriter := zlib.NewWriter(&zlibBuf)
	_, err := zlibWriter.Write([]byte("not-gzip-data"))
	if err != nil {
		t.Fatalf("failed to zlib-compress data: %v", err)
	}
	if err := zlibWriter.Close(); err != nil {
		t.Fatalf("failed to close zlib writer: %v", err)
	}

	resp := &http.Response{
		Header: http.Header{
			// readDecodedResponseBody should decode in reverse order:
			// first "deflate" (zlib), then "gzip". The second step should fail.
			"Content-Encoding": []string{"gzip", "deflate"},
		},
		Body: io.NopCloser(bytes.NewReader(zlibBuf.Bytes())),
	}

	_, err = readDecodedResponseBody(resp)
	assert.Error(t, err)
}

These tests assume that readDecodedResponseBody treats "deflate" as zlib-wrapped data and that it iterates the encodings slice in reverse order, as suggested by your comment. If the implementation instead:

  1. Uses compress/flate directly for "deflate", or
  2. Parses a single header value like "gzip, deflate" rather than []string{"gzip", "deflate"},

you should adjust:

  • The encoder (encodeWithGzipThenZlib) to match the actual "deflate" implementation (e.g., using flate.NewWriter instead of zlib.NewWriter), and/or
  • The Content-Encoding header construction to mirror how your code reads/normalizes multiple encodings (e.g., Header: http.Header{"Content-Encoding": []string{"gzip, deflate"}}).

assert.Nil(t, res)
}

func TestExecuteHTTPTool_GzipResponse(t *testing.T) {
// downstream returns gzip-compressed JSON
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte(`{"hello":"gzip"}`))
_ = gz.Close()

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "gzip")
_, _ = w.Write(buf.Bytes())
}))
defer srv.Close()

allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"})
s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist}
tool := &config.ToolConfig{
Name: "t",
Method: http.MethodGet,
Endpoint: srv.URL,
Headers: map[string]string{"Accept-Encoding": "gzip"},
ResponseBody: "{{.Response.Body}}",
}
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{"X-Req": "v"}}}}
c, _ := gin.CreateTestContext(nil)
c.Request = req

res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{})
assert.NoError(t, err)
if assert.NotNil(t, res) {
if tc, ok := res.Content[0].(*mcp.TextContent); ok {
assert.Equal(t, `{"hello":"gzip"}`, tc.Text)
} else {
t.Fatalf("unexpected content type")
}
}
}

func TestReadDecodedResponseBody_GzipInvalidData(t *testing.T) {
resp := &http.Response{
Header: http.Header{"Content-Encoding": []string{"gzip"}},
Body: io.NopCloser(bytes.NewBufferString("not-gzip-data")),
}

_, err := readDecodedResponseBody(resp)
assert.Error(t, err)
}

func TestExecuteHTTPTool_BrotliResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var buf bytes.Buffer
br := brotli.NewWriter(&buf)
_, _ = br.Write([]byte(`{"hello":"brotli"}`))
_ = br.Close()

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "br")
_, _ = w.Write(buf.Bytes())
}))
defer srv.Close()

allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"})
s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist}
tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"}
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}}
c, _ := gin.CreateTestContext(nil)
c.Request = req

res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{})
assert.NoError(t, err)
if assert.NotNil(t, res) {
tc, ok := res.Content[0].(*mcp.TextContent)
if !ok {
t.Fatalf("unexpected content type")
}
assert.Equal(t, `{"hello":"brotli"}`, tc.Text)
}
}

func TestExecuteHTTPTool_ZstdResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var buf bytes.Buffer
zw, err := zstd.NewWriter(&buf)
assert.NoError(t, err)
_, _ = zw.Write([]byte(`{"hello":"zstd"}`))
zw.Close()

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "zstd")
_, _ = w.Write(buf.Bytes())
}))
defer srv.Close()

allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"})
s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist}
tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"}
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}}
c, _ := gin.CreateTestContext(nil)
c.Request = req

res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{})
assert.NoError(t, err)
if assert.NotNil(t, res) {
tc, ok := res.Content[0].(*mcp.TextContent)
if !ok {
t.Fatalf("unexpected content type")
}
assert.Equal(t, `{"hello":"zstd"}`, tc.Text)
}
}

func TestExecuteHTTPTool_DeflateResponse(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var buf bytes.Buffer
zw := zlib.NewWriter(&buf)
_, _ = zw.Write([]byte(`{"hello":"deflate"}`))
_ = zw.Close()

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Encoding", "deflate")
_, _ = w.Write(buf.Bytes())
}))
defer srv.Close()

allowlist, _ := parseInternalNetworkAllowlist([]string{"127.0.0.0/8", "::1/128"})
s := &Server{logger: zap.NewNop(), toolRespHandler: CreateResponseHandlerChain(), internalNetACL: allowlist}
tool := &config.ToolConfig{Name: "t", Method: http.MethodGet, Endpoint: srv.URL, ResponseBody: "{{.Response.Body}}"}
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
conn := &fakeConnExec{meta: &session.Meta{ID: "sid", Request: &session.RequestInfo{Headers: map[string]string{}}}}
c, _ := gin.CreateTestContext(nil)
c.Request = req

res, err := s.executeHTTPTool(c, conn, tool, map[string]any{}, map[string]string{})
assert.NoError(t, err)
if assert.NotNil(t, res) {
tc, ok := res.Content[0].(*mcp.TextContent)
if !ok {
t.Fatalf("unexpected content type")
}
assert.Equal(t, `{"hello":"deflate"}`, tc.Text)
}
}

func TestReadDecodedResponseBody_DeflateRaw(t *testing.T) {
var buf bytes.Buffer
fw, err := flate.NewWriter(&buf, flate.DefaultCompression)
assert.NoError(t, err)
_, _ = fw.Write([]byte(`{"hello":"raw-deflate"}`))
_ = fw.Close()

resp := &http.Response{
Header: http.Header{"Content-Encoding": []string{"deflate"}},
Body: io.NopCloser(bytes.NewReader(buf.Bytes())),
}

body, err := readDecodedResponseBody(resp)
assert.NoError(t, err)
assert.Equal(t, `{"hello":"raw-deflate"}`, string(body))
}
108 changes: 105 additions & 3 deletions internal/core/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package core

import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"context"
"encoding/json"
"fmt"
Expand All @@ -13,6 +16,8 @@ import (
"slices"
"strings"

"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"golang.org/x/net/proxy"
Expand All @@ -30,6 +35,101 @@ import (
"github.com/amoylab/unla/pkg/mcp"
)

func parseContentEncodings(contentEncoding string) []string {
parts := strings.Split(contentEncoding, ",")
encodings := make([]string, 0, len(parts))
for _, part := range parts {
encoding := strings.ToLower(strings.TrimSpace(part))
if encoding == "" || encoding == "identity" {
continue
}
encodings = append(encodings, encoding)
}
return encodings
}

func decodeBodyBytesByEncoding(body []byte, encoding string) ([]byte, error) {
var (
reader io.Reader = bytes.NewReader(body)
closeFn func() error
decodeError error
)

switch encoding {
case "gzip", "x-gzip":
gzReader, err := gzip.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
reader = gzReader
closeFn = gzReader.Close
case "br":
reader = brotli.NewReader(reader)
case "zstd":
zstdReader, err := zstd.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
}
reader = zstdReader
closeFn = func() error {
zstdReader.Close()
return nil
}
Comment on lines +68 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Preserve and surface errors from zstdReader.Close instead of discarding them

In the zstd branch, zstdReader.Close() is wrapped in a closeFn that always returns nil, so any close error is silently dropped while closeErr is still being checked by callers. To align with the other codecs and preserve potential failures, you can wire the close error through:

case "zstd":
	zstdReader, err := zstd.NewReader(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to create zstd reader: %w", err)
	}
	reader = zstdReader
	closeFn = func() error {
		return zstdReader.Close()
	}

This keeps error handling consistent and ensures zstd close errors are observable to callers.

Suggested change
case "zstd":
zstdReader, err := zstd.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
}
reader = zstdReader
closeFn = func() error {
zstdReader.Close()
return nil
}
case "zstd":
zstdReader, err := zstd.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
}
reader = zstdReader
closeFn = func() error {
return zstdReader.Close()
}

case "deflate":
zlibReader, err := zlib.NewReader(reader)
if err == nil {
reader = zlibReader
closeFn = zlibReader.Close
break
}
// Some servers send raw deflate stream for deflate token.
flateReader := flate.NewReader(bytes.NewReader(body))
reader = flateReader
closeFn = flateReader.Close
default:
return nil, fmt.Errorf("unsupported content encoding: %s", encoding)
}

decoded, err := io.ReadAll(reader)
if err != nil {
decodeError = fmt.Errorf("failed to decode %s body: %w", encoding, err)
}
if closeFn != nil {
closeErr := closeFn()
if decodeError != nil {
return nil, decodeError
}
if closeErr != nil {
return nil, fmt.Errorf("failed to close %s decoder: %w", encoding, closeErr)
}
}
if decodeError != nil {
return nil, decodeError
}

return decoded, nil
}

func readDecodedResponseBody(resp *http.Response) ([]byte, error) {
if resp == nil || resp.Body == nil {
return nil, nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

encodings := parseContentEncodings(resp.Header.Get("Content-Encoding"))
for i := len(encodings) - 1; i >= 0; i-- {
body, err = decodeBodyBytesByEncoding(body, encodings[i])
if err != nil {
return nil, err
}
}
return body, nil
}

// shouldIgnoreHeader checks if a header should be ignored based on configuration
func (s *Server) shouldIgnoreHeader(headerName string) bool {
// If forward is disabled, don't ignore any headers (backward compatibility)
Expand Down Expand Up @@ -330,8 +430,8 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool *
}
defer resp.Body.Close()

// Read response body for logging in case of error
respBodyBytes, err := io.ReadAll(resp.Body)
// Read and decode response body for logging and unified downstream processing.
respBodyBytes, err := readDecodedResponseBody(resp)
if err != nil {
logger.Error("failed to read response body",
zap.String("tool", tool.Name),
Expand All @@ -341,8 +441,10 @@ func (s *Server) executeHTTPTool(c *gin.Context, conn session.Connection, tool *
return nil, fmt.Errorf("failed to read response body: %w", err)
}

// Restore response body for further processing
// Restore decoded response body for further processing.
resp.Body = io.NopCloser(bytes.NewBuffer(respBodyBytes))
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
Comment on lines +444 to +447
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Consider updating resp.ContentLength alongside the Content-Length header after decoding

Right now the resp.ContentLength field still reflects the original (encoded) length, so it can disagree with the decoded body you’ve set on resp.Body. Any caller using resp.ContentLength will see the wrong value. After decoding, also update the field, e.g.:

resp.ContentLength = int64(len(respBodyBytes)) // or -1 if you want it treated as unknown
Suggested change
// Restore decoded response body for further processing.
resp.Body = io.NopCloser(bytes.NewBuffer(respBodyBytes))
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
// Restore decoded response body for further processing.
resp.Body = io.NopCloser(bytes.NewBuffer(respBodyBytes))
resp.ContentLength = int64(len(respBodyBytes))
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")


// Log response status
logger.Debug("received HTTP response",
Expand Down
Loading