Skip to content

Commit 1e3672c

Browse files
authored
feat: add header-based authentication for HTTP transports (#57)
Adds optional token auth via AUTH_HEADER and AUTH_VALUE env vars. When configured, requests without the correct header receive 401. This is supposed to be used in tests.
1 parent cf3edb9 commit 1e3672c

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

cmd/yardstick-server/main.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"flag"
67
"fmt"
78
"log"
@@ -28,11 +29,33 @@ type EchoResponse struct {
2829
var alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
2930
var transport string
3031
var port int
32+
var authHeader string
33+
var authValue string
3134

3235
func validateAlphanumeric(input string) bool {
3336
return alphanumericRegex.MatchString(input)
3437
}
3538

39+
func checkAuth(r *http.Request) error {
40+
if authHeader == "" {
41+
return nil
42+
}
43+
if r.Header.Get(authHeader) != authValue {
44+
return errors.New("unauthorized")
45+
}
46+
return nil
47+
}
48+
49+
func authWrapper(next http.Handler) http.Handler {
50+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51+
if err := checkAuth(r); err != nil {
52+
http.Error(w, err.Error(), http.StatusUnauthorized)
53+
return
54+
}
55+
next.ServeHTTP(w, r)
56+
})
57+
}
58+
3659
func echoHandler(_ context.Context, _ *mcp.CallToolRequest, params EchoRequest) (*mcp.CallToolResult, EchoResponse, error) {
3760
if !validateAlphanumeric(params.Input) {
3861
return &mcp.CallToolResult{
@@ -98,7 +121,7 @@ func main() {
98121
}, nil)
99122

100123
// Mount the SSE handler at /sse - it will handle both GET (SSE stream) and POST (messages) requests
101-
http.Handle("/sse", handler)
124+
http.Handle("/sse", authWrapper(handler))
102125

103126
// Create server with timeouts to address G114 gosec issue
104127
srv := &http.Server{
@@ -116,7 +139,7 @@ func main() {
116139
return server
117140
}, nil)
118141

119-
http.Handle("/mcp", handler)
142+
http.Handle("/mcp", authWrapper(handler))
120143

121144
// Create server with timeouts to address G114 gosec issue
122145
srv := &http.Server{
@@ -150,4 +173,7 @@ func parseConfig() {
150173
port = intValue
151174
}
152175
}
176+
177+
authHeader = os.Getenv("AUTH_HEADER")
178+
authValue = os.Getenv("AUTH_VALUE")
153179
}

cmd/yardstick-server/main_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"net/http"
56
"testing"
67

78
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -111,3 +112,79 @@ func TestEchoResponseCreation(t *testing.T) {
111112
response := EchoResponse{Output: "test123"}
112113
assert.Equal(t, "test123", response.Output)
113114
}
115+
116+
func TestCheckAuth_HeaderAuth(t *testing.T) {
117+
// Save original values
118+
origHeader := authHeader
119+
origValue := authValue
120+
defer func() {
121+
authHeader = origHeader
122+
authValue = origValue
123+
}()
124+
125+
// Set auth config
126+
authHeader = "X-Auth-Token"
127+
authValue = "secret123"
128+
129+
// Create request with correct header
130+
req, err := http.NewRequest(http.MethodGet, "/test", nil)
131+
assert.NoError(t, err)
132+
req.Header.Set("X-Auth-Token", "secret123")
133+
134+
// Should pass authentication
135+
err = checkAuth(req)
136+
assert.NoError(t, err)
137+
}
138+
139+
func TestCheckAuth_HeaderAuth_Fail(t *testing.T) {
140+
// Save original values
141+
origHeader := authHeader
142+
origValue := authValue
143+
defer func() {
144+
authHeader = origHeader
145+
authValue = origValue
146+
}()
147+
148+
// Set auth config
149+
authHeader = "X-Auth-Token"
150+
authValue = "secret123"
151+
152+
// Test with wrong header value
153+
req, err := http.NewRequest(http.MethodGet, "/test", nil)
154+
assert.NoError(t, err)
155+
req.Header.Set("X-Auth-Token", "wrongvalue")
156+
157+
err = checkAuth(req)
158+
assert.Error(t, err)
159+
assert.Equal(t, "unauthorized", err.Error())
160+
161+
// Test with missing header
162+
req2, err := http.NewRequest(http.MethodGet, "/test", nil)
163+
assert.NoError(t, err)
164+
165+
err = checkAuth(req2)
166+
assert.Error(t, err)
167+
assert.Equal(t, "unauthorized", err.Error())
168+
}
169+
170+
func TestCheckAuth_Disabled(t *testing.T) {
171+
// Save original values
172+
origHeader := authHeader
173+
origValue := authValue
174+
defer func() {
175+
authHeader = origHeader
176+
authValue = origValue
177+
}()
178+
179+
// Auth disabled when authHeader is empty
180+
authHeader = ""
181+
authValue = ""
182+
183+
// Create request without any auth header
184+
req, err := http.NewRequest(http.MethodGet, "/test", nil)
185+
assert.NoError(t, err)
186+
187+
// Should pass since auth is disabled
188+
err = checkAuth(req)
189+
assert.NoError(t, err)
190+
}

0 commit comments

Comments
 (0)