diff --git a/internal/exporter.go b/internal/exporter.go index b603d40..d9d2ebb 100644 --- a/internal/exporter.go +++ b/internal/exporter.go @@ -135,10 +135,16 @@ func (exporter *Exporter) Serve() error { return web.Serve(exporter.listener, exporter.server, &toolkitFlags, slog.Default()) } -// Shutdown : Properly tear down server +// Shutdown : Properly tear down server tears down server using a default background context func (exporter *Exporter) Shutdown() error { + return exporter.ShutdownWithContext(context.Background()) +} + +// ShutdownWithContext properly tears down the server using the provided context +// This allows for setting timeouts and other context-specific configurations +func (exporter *Exporter) ShutdownWithContext(ctx context.Context) error { if exporter.server != nil { - return exporter.server.Shutdown(context.Background()) + return exporter.server.Shutdown(ctx) } return nil diff --git a/internal/exporter_test.go b/internal/exporter_test.go index 0f307fe..553b4ef 100644 --- a/internal/exporter_test.go +++ b/internal/exporter_test.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -912,6 +913,32 @@ func TestMultipleShutdown(t *testing.T) { assert.NoError(t, err) } +func TestShutdownWithContext(t *testing.T) { + exporter := &Exporter{ListenAddress: "127.0.0.1:4243"} + err := exporter.Listen() + assert.NoError(t, err) + + // Shutdown with normal context + ctx := context.Background() + err = exporter.ShutdownWithContext(ctx) + assert.NoError(t, err) + + // Should be safe to call shutdown again even after server is already shutdown + err = exporter.ShutdownWithContext(ctx) + assert.NoError(t, err) + + // Test with timeout context + exporter = &Exporter{ListenAddress: "127.0.0.1:4244"} + err = exporter.Listen() + assert.NoError(t, err) + + timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = exporter.ShutdownWithContext(timeoutCtx) + assert.NoError(t, err) +} + func testSinglePEM(t *testing.T, expired float64, notBefore time.Time) { certPath := "/tmp/test.pem" generateCertificate(certPath, notBefore)