Skip to content

Commit 9d7f9d0

Browse files
authored
feat: refactor backend architecture with interface pattern and trusted proxy support (#72)
* feat: refactor backend architecture with interface pattern and trusted proxy support - Extract backend interface for better separation of concerns - Split monolithic backend.go into separate proxy and transparent backends - Add trusted proxy configuration support with IP/CIDR validation - Implement transparent backend for HTTP/HTTPS proxy targets - Add comprehensive test coverage for new backend implementations - Maintain backward compatibility with existing proxy functionality * fix: enable trusted proxy configuration Uncommented the SetTrustedProxies call to properly configure trusted proxies
1 parent a89634f commit 9d7f9d0

File tree

8 files changed

+311
-5
lines changed

8 files changed

+311
-5
lines changed

main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func main() {
5353
var passwordHash string
5454
var proxyBearerToken string
5555
var proxyHeaders string
56+
var trustedProxies string
5657

5758
rootCmd := &cobra.Command{
5859
Use: "mcp-warp",
@@ -107,6 +108,14 @@ func main() {
107108
oidcScopesList = []string{"openid", "profile", "email"}
108109
}
109110

111+
var trustedProxiesList []string
112+
if trustedProxies != "" {
113+
trustedProxiesList = strings.Split(trustedProxies, ",")
114+
for i := range trustedProxiesList {
115+
trustedProxiesList[i] = strings.TrimSpace(trustedProxiesList[i])
116+
}
117+
}
118+
110119
// Parse proxy headers into slice
111120
var proxyHeadersList []string
112121
if proxyHeaders != "" {
@@ -142,6 +151,7 @@ func main() {
142151
oidcAllowedUsersList,
143152
password,
144153
passwordHash,
154+
trustedProxiesList,
145155
proxyHeadersList,
146156
proxyBearerToken,
147157
args,
@@ -187,6 +197,7 @@ func main() {
187197

188198
// Proxy headers configuration
189199
rootCmd.Flags().StringVar(&proxyBearerToken, "proxy-bearer-token", getEnvWithDefault("PROXY_BEARER_TOKEN", ""), "Bearer token to add to Authorization header when proxying requests")
200+
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
190201
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")
191202

192203
if err := rootCmd.Execute(); err != nil {

pkg/backend/interface.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package backend
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
type Backend interface {
9+
Run(context.Context) (http.Handler, error)
10+
Wait() error
11+
Close() error
12+
}

pkg/backend/main_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package backend
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/gin-gonic/gin"
8+
)
9+
10+
func TestMain(m *testing.M) {
11+
gin.SetMode(gin.TestMode)
12+
os.Exit(m.Run())
13+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type ProxyBackend struct {
2525
client *client.Client
2626
}
2727

28-
func NewProxyBackend(logger *zap.Logger, cmd []string) *ProxyBackend {
28+
func NewProxyBackend(logger *zap.Logger, cmd []string) Backend {
2929
return &ProxyBackend{
3030
logger: logger,
3131
cmd: cmd,
Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,32 @@ func TestProxyBackendRun(t *testing.T) {
7878
defer pb.Close()
7979

8080
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
81-
defer cancel()
8281

8382
handler, err := pb.Run(ctx)
8483
require.NoError(t, err, "Run should not return error")
8584
require.NotNil(t, handler, "handler should not be nil")
85+
86+
checkCh := make(chan struct{})
87+
go func() {
88+
<-ctx.Done()
89+
close(checkCh)
90+
}()
91+
92+
timeout := time.After(10 * time.Millisecond)
93+
select {
94+
case <-checkCh:
95+
t.Error("Test completed too early")
96+
case <-timeout:
97+
// Test timed out
98+
}
99+
100+
cancel()
101+
102+
timeout = time.After(10 * time.Second)
103+
select {
104+
case <-checkCh:
105+
// Test completed successfully
106+
case <-timeout:
107+
t.Error("Test timed out")
108+
}
86109
}

pkg/backend/transparent.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package backend
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"net/http/httputil"
9+
"net/netip"
10+
"net/url"
11+
"sync"
12+
13+
"go.uber.org/zap"
14+
)
15+
16+
type TransparentBackend struct {
17+
logger *zap.Logger
18+
url *url.URL
19+
trusted []netip.Prefix
20+
ctx context.Context
21+
ctxLock sync.Mutex
22+
}
23+
24+
func NewTransparentBackend(logger *zap.Logger, u *url.URL, trusted []string) (Backend, error) {
25+
trn := make([]netip.Prefix, 0, len(trusted))
26+
for _, c := range trusted {
27+
p, err := netip.ParsePrefix(c)
28+
if err != nil {
29+
return nil, err
30+
}
31+
trn = append(trn, p)
32+
}
33+
34+
return &TransparentBackend{
35+
logger: logger,
36+
url: u,
37+
trusted: trn,
38+
}, nil
39+
}
40+
41+
func (p *TransparentBackend) Run(ctx context.Context) (http.Handler, error) {
42+
p.ctxLock.Lock()
43+
defer p.ctxLock.Unlock()
44+
if p.ctx != nil {
45+
return nil, fmt.Errorf("transparent backend is already running")
46+
}
47+
p.ctx = ctx
48+
rp := httputil.ReverseProxy{
49+
Rewrite: func(pr *httputil.ProxyRequest) {
50+
pr.SetURL(p.url)
51+
if p.isTrusted(pr.In.RemoteAddr) {
52+
pr.Out.Header["X-Forwarded-For"] = pr.In.Header["X-Forwarded-For"]
53+
}
54+
pr.SetXForwarded()
55+
if p.isTrusted(pr.In.RemoteAddr) {
56+
if v := pr.In.Header.Get("X-Forwarded-Host"); v != "" {
57+
pr.Out.Header.Set("X-Forwarded-Host", v)
58+
}
59+
if v := pr.In.Header.Get("X-Forwarded-Proto"); v != "" {
60+
pr.Out.Header.Set("X-Forwarded-Proto", v)
61+
}
62+
if v := pr.In.Header.Get("X-Forwarded-Port"); v != "" {
63+
pr.Out.Header.Set("X-Forwarded-Port", v)
64+
}
65+
}
66+
},
67+
}
68+
return &rp, nil
69+
}
70+
71+
func (p *TransparentBackend) isTrusted(hostport string) bool {
72+
if host, _, err := net.SplitHostPort(hostport); err == nil {
73+
hostport = host
74+
}
75+
ip, err := netip.ParseAddr(hostport)
76+
if err != nil {
77+
return false
78+
}
79+
if ip.Is4In6() {
80+
ip = ip.Unmap()
81+
}
82+
for _, p := range p.trusted {
83+
if p.Contains(ip) {
84+
return true
85+
}
86+
}
87+
return false
88+
}
89+
90+
func (p *TransparentBackend) Wait() error {
91+
if p.ctx == nil {
92+
return nil
93+
}
94+
<-p.ctx.Done()
95+
return nil
96+
}
97+
98+
func (p *TransparentBackend) Close() error {
99+
return nil
100+
}

pkg/backend/transparent_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package backend
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/url"
9+
"testing"
10+
"time"
11+
12+
"github.com/gin-gonic/gin"
13+
"github.com/stretchr/testify/require"
14+
"go.uber.org/zap"
15+
)
16+
17+
func TestTransparentBackend(t *testing.T) {
18+
r := gin.New()
19+
r.GET("/", func(c *gin.Context) {
20+
c.JSON(http.StatusOK, c.Request.Header)
21+
})
22+
ts := httptest.NewServer(r)
23+
u, _ := url.Parse(ts.URL)
24+
25+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
26+
require.NoError(t, err)
27+
handler, err := be.Run(context.Background())
28+
require.NoError(t, err)
29+
require.NotNil(t, handler)
30+
31+
req := httptest.NewRequest(http.MethodGet, "/", nil)
32+
rr := httptest.NewRecorder()
33+
handler.ServeHTTP(rr, req)
34+
35+
require.Equal(t, http.StatusOK, rr.Code)
36+
var header http.Header
37+
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
38+
require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For")))
39+
require.Equal(t, "example.com", header.Get(("X-Forwarded-Host")))
40+
require.Equal(t, "http", header.Get(("X-Forwarded-Proto")))
41+
}
42+
43+
func TestTransparentBackendWithProxy(t *testing.T) {
44+
r := gin.New()
45+
r.GET("/", func(c *gin.Context) {
46+
c.JSON(http.StatusOK, c.Request.Header)
47+
})
48+
ts := httptest.NewServer(r)
49+
u, _ := url.Parse(ts.URL)
50+
51+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{"0.0.0.0/0"})
52+
require.NoError(t, err)
53+
handler, err := be.Run(context.Background())
54+
require.NoError(t, err)
55+
require.NotNil(t, handler)
56+
57+
req := httptest.NewRequest(http.MethodGet, "/", nil)
58+
req.Header.Set("X-Forwarded-For", "192.0.3.1")
59+
req.Header.Set("X-Forwarded-Host", "example.org")
60+
req.Header.Set("X-Forwarded-Proto", "https")
61+
rr := httptest.NewRecorder()
62+
handler.ServeHTTP(rr, req)
63+
64+
require.Equal(t, http.StatusOK, rr.Code)
65+
var header http.Header
66+
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
67+
require.Equal(t, "192.0.3.1, 192.0.2.1", header.Get(("X-Forwarded-For")))
68+
require.Equal(t, "example.org", header.Get(("X-Forwarded-Host")))
69+
require.Equal(t, "https", header.Get(("X-Forwarded-Proto")))
70+
}
71+
72+
func TestTransparentBackendWithInvalidProxy(t *testing.T) {
73+
r := gin.New()
74+
r.GET("/", func(c *gin.Context) {
75+
c.JSON(http.StatusOK, c.Request.Header)
76+
})
77+
ts := httptest.NewServer(r)
78+
u, _ := url.Parse(ts.URL)
79+
80+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{"1.1.1.1/32"})
81+
require.NoError(t, err)
82+
handler, err := be.Run(context.Background())
83+
require.NoError(t, err)
84+
require.NotNil(t, handler)
85+
86+
req := httptest.NewRequest(http.MethodGet, "/", nil)
87+
req.Header.Set("X-Forwarded-For", "192.0.3.1")
88+
req.Header.Set("X-Forwarded-Host", "example.org")
89+
req.Header.Set("X-Forwarded-Proto", "https")
90+
rr := httptest.NewRecorder()
91+
handler.ServeHTTP(rr, req)
92+
93+
require.Equal(t, http.StatusOK, rr.Code)
94+
var header http.Header
95+
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
96+
require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For")))
97+
require.Equal(t, "example.com", header.Get(("X-Forwarded-Host")))
98+
require.Equal(t, "http", header.Get(("X-Forwarded-Proto")))
99+
}
100+
101+
func TestTransparentBackendRun(t *testing.T) {
102+
r := gin.New()
103+
r.GET("/", func(c *gin.Context) {
104+
c.JSON(http.StatusOK, c.Request.Header)
105+
})
106+
ts := httptest.NewServer(r)
107+
u, _ := url.Parse(ts.URL)
108+
109+
be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
110+
require.NoError(t, err)
111+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
112+
_, err = be.Run(ctx)
113+
require.NoError(t, err)
114+
115+
checkCh := make(chan struct{})
116+
go func() {
117+
<-ctx.Done()
118+
close(checkCh)
119+
}()
120+
121+
timeout := time.After(10 * time.Millisecond)
122+
select {
123+
case <-checkCh:
124+
t.Error("Test completed too early")
125+
case <-timeout:
126+
// Test timed out
127+
}
128+
129+
cancel()
130+
131+
timeout = time.After(10 * time.Second)
132+
select {
133+
case <-checkCh:
134+
// Test completed successfully
135+
case <-timeout:
136+
t.Error("Test timed out")
137+
}
138+
}

pkg/mcp-proxy/main.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"net/http"
8-
"net/http/httputil"
98
"net/url"
109
"os"
1110
"os/signal"
@@ -59,6 +58,7 @@ func Run(
5958
oidcAllowedUsers []string,
6059
password string,
6160
passwordHash string,
61+
trustedProxy []string,
6262
proxyHeaders []string,
6363
proxyBearerToken string,
6464
proxyTarget []string,
@@ -98,10 +98,18 @@ func Run(
9898
if len(proxyTarget) == 0 {
9999
return fmt.Errorf("proxy target must be specified")
100100
}
101-
var be *backend.ProxyBackend
101+
var be backend.Backend
102102
var beHandler http.Handler
103103
if proxyURL, err := url.Parse(proxyTarget[0]); err == nil && (proxyURL.Scheme == "http" || proxyURL.Scheme == "https") {
104-
beHandler = httputil.NewSingleHostReverseProxy(proxyURL)
104+
var err error
105+
be, err = backend.NewTransparentBackend(logger, proxyURL, trustedProxy)
106+
if err != nil {
107+
return fmt.Errorf("failed to create transparent backend: %w", err)
108+
}
109+
beHandler, err = be.Run(ctx)
110+
if err != nil {
111+
return fmt.Errorf("failed to create transparent backend: %w", err)
112+
}
105113
} else {
106114
be = backend.NewProxyBackend(logger, proxyTarget)
107115
beHandler, err = be.Run(ctx)
@@ -205,6 +213,7 @@ func Run(
205213
}
206214

207215
router := gin.New()
216+
router.SetTrustedProxies(trustedProxy)
208217

209218
router.Use(ginzap.Ginzap(logger, time.RFC3339, true))
210219
router.Use(ginzap.RecoveryWithZap(logger, true))

0 commit comments

Comments
 (0)