diff --git a/instrumentation/github.com/gorilla/mux/otelmux/mux.go b/instrumentation/github.com/gorilla/mux/otelmux/mux.go index c7b2355eca8..0e1168ee20e 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/mux.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/mux.go @@ -4,7 +4,9 @@ package otelmux // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" import ( + "bufio" "fmt" + "net" "net/http" "sync" @@ -76,6 +78,13 @@ type recordingResponseWriter struct { status int } +func (h *recordingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := h.writer.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, fmt.Errorf("underlying ResponseWriter does not support hijacking") +} + var rrwPool = &sync.Pool{ New: func() interface{} { return &recordingResponseWriter{} diff --git a/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go b/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go index d5e489b2eb6..859b61fcff5 100644 --- a/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go +++ b/instrumentation/github.com/gorilla/mux/otelmux/test/mux_test.go @@ -23,6 +23,63 @@ import ( "go.opentelemetry.io/otel/trace" ) +func TestRecordingResponseWriterHijackWithMiddleware(t *testing.T) { + // Create a mock HTTP handler + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + require.True(t, ok, "ResponseWriter does not implement http.Hijacker") + + conn, rw, err := hj.Hijack() + require.NoError(t, err) + assert.NotNil(t, conn) + assert.NotNil(t, rw) + + err = conn.Close() + require.NoError(t, err) + }) + + // Wrap the handler with otelmux.Middleware + router := mux.NewRouter() + router.Use(otelmux.Middleware("test-service")) + router.Handle("/hijack", mockHandler) + + // Create a mock HTTP request and response writer + req := httptest.NewRequest("GET", "http://example.com/hijack", nil) + rr := httptest.NewRecorder() + + // Serve the HTTP request using the wrapped handler + router.ServeHTTP(rr, req) + + // Verify the response status + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestRecordingResponseWriterHijackNonHijackerWithMiddleware(t *testing.T) { + // Create a mock HTTP handler + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + require.False(t, ok, "ResponseWriter should not implement http.Hijacker") + + _, _, err := hj.Hijack() + assert.Error(t, err) + }) + + // Wrap the handler with otelmux.Middleware + router := mux.NewRouter() + router.Use(otelmux.Middleware("test-service")) + router.Handle("/non-hijack", mockHandler) + + // Create a mock HTTP request and response writer + req := httptest.NewRequest("GET", "http://example.com/non-hijack", nil) + rr := httptest.NewRecorder() + + // Serve the HTTP request using the wrapped handler + router.ServeHTTP(rr, req) + + // Verify the response status + assert.Equal(t, http.StatusOK, rr.Code) +} + func TestCustomSpanNameFormatter(t *testing.T) { exporter := tracetest.NewInMemoryExporter()