diff --git a/.gitignore b/.gitignore index 35531439..9f72feeb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ /.tmp/ *.pprof *.svg +.idea cover.out connect.test diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index f8046426..49fc3328 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -33,6 +33,11 @@ type Server struct { serverWG sync.WaitGroup serverErr error + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client + clientMu sync.Mutex } // NewServer creates a new Server that uses the given handler. Configuration @@ -94,12 +99,18 @@ func (s *Server) TransportHTTP1() *http.Transport { } // Client returns an [http.Client] configured to use in-memory pipes rather -// than TCP and speak HTTP/2. It is configured to use the same -// [http2.Transport] as [Transport]. +// than TCP and speak HTTP/2. // -// Callers may reconfigure the returned client without affecting other clients. +// Client is configured to use the same transport for the lifetime of the +// server, and its idle connections are automatically closed when the +// server is closed. func (s *Server) Client() *http.Client { - return &http.Client{Transport: s.Transport()} + s.clientMu.Lock() + defer s.clientMu.Unlock() + if s.client == nil { + s.client = &http.Client{Transport: s.Transport()} + } + return s.client } // URL returns the server's URL. @@ -110,6 +121,11 @@ func (s *Server) URL() string { // Shutdown gracefully shuts down the server, without interrupting any active // connections. See [http.Server.Shutdown] for details. func (s *Server) Shutdown(ctx context.Context) error { + s.clientMu.Lock() + if s.client != nil { + s.client.CloseIdleConnections() + } + s.clientMu.Unlock() if err := s.server.Shutdown(ctx); err != nil { return err } @@ -128,6 +144,11 @@ func (s *Server) Cleanup() error { // Close closes the server's listener. It does not wait for connections to // finish. func (s *Server) Close() error { + s.clientMu.Lock() + if s.client != nil { + s.client.CloseIdleConnections() + } + s.clientMu.Unlock() return s.server.Close() } diff --git a/internal/memhttp/memhttptest/http_test.go b/internal/memhttp/memhttptest/http_test.go new file mode 100644 index 00000000..8f62422a --- /dev/null +++ b/internal/memhttp/memhttptest/http_test.go @@ -0,0 +1,112 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.25 + +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttptest_test + +import ( + "bytes" + "io" + "net/http" + "strings" + "testing" + "testing/synctest" + + "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" +) + +// TestMemhttpWithSynctest verifies that memhttp works correctly with synctest. +func TestMemhttpWithSynctest(t *testing.T) { + t.Parallel() + body := "request body" + + handler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + buf := &bytes.Buffer{} + _, err := io.Copy(buf, request.Body) + if err != nil { + t.Errorf("failed to copy body: %v", err) + } + if buf.String() != body { + t.Errorf("got body %q, want %q", buf.String(), body) + } + writer.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + client func(*testing.T, *memhttp.Server) *http.Client + }{ + { + name: "server.Client()", + client: func(t *testing.T, s *memhttp.Server) *http.Client { + t.Helper() + return s.Client() + }, + }, + { + name: "Custom Client HTTP/1", + client: func(t *testing.T, s *memhttp.Server) *http.Client { + t.Helper() + // HTTP/1.1's is a per-request closure, so nothing leaks outside the bubble. + return &http.Client{Transport: s.TransportHTTP1()} + }, + }, + { + name: "Custom Client HTTP/2", + client: func(t *testing.T, s *memhttp.Server) *http.Client { + t.Helper() + // HTTP/2 a goroutine running for future connections, which leaks outside the bubble. + client := &http.Client{Transport: s.Transport()} + // Closing idle connections here ensures synctest doesn't panic. + t.Cleanup(client.CloseIdleConnections) + return client + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + t.Helper() + server := memhttptest.NewServer(t, handler) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPut, server.URL(), strings.NewReader(body)) + assert.Nil(t, err) + + client := test.client(t, server) + resp, err := client.Do(req) + assert.Nil(t, err) + resp.Body.Close() + }) + }) + } +}