From f66f10b2ad075909a861ea3efcf9aa25ce288b69 Mon Sep 17 00:00:00 2001 From: Fuyang Liu Date: Thu, 16 Nov 2023 16:22:52 +0100 Subject: [PATCH] feat(run): provides RunWithGracefulShutdownHook So to allow user easily wire up a clean up function that can be called BEFORE the root context is done. --- run.go | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++-- run_test.go | 84 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 175 insertions(+), 9 deletions(-) diff --git a/run.go b/run.go index f2cb9988..785efb4f 100644 --- a/run.go +++ b/run.go @@ -9,6 +9,7 @@ import ( "os/signal" "runtime/debug" "syscall" + "time" "go.einride.tech/cloudrunner/cloudclient" "go.einride.tech/cloudrunner/cloudconfig" @@ -27,6 +28,14 @@ import ( "google.golang.org/grpc" ) +// gracefulShutdownMaxGracePeriod is the maximum time we wait for the service to finish calling its cancel function +// after a SIGTERM/SIGINT is sent to us. +// If user is using cloudrunner in a Kubernetes like environment, make sure to set `terminationGracePeriodSeconds` +// (default as 30 seconds) above this value to make sure Kubernetes can wait for enough time for graceful shutdown. +// More info see here: +// https://cloud.google.com/blog/products/containers-kubernetes/kubernetes-best-practices-terminating-with-grace +const gracefulShutdownMaxGracePeriod = time.Second * 10 + // runConfig configures the Run entrypoint from environment variables. type runConfig struct { // Runtime contains runtime config. @@ -49,8 +58,41 @@ type runConfig struct { // Run a service. // Configuration of the service is loaded from the environment. +// +// Example usage code can be like: +// +// err := cloudrunner.Run(func(ctx context.Context) error { +// grpcServer := cloudrunner.NewGRPCServer(ctx) +// return cloudrunner.ListenGRPC(ctx, grpcServer) +// }) func Run(fn func(context.Context) error, options ...Option) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + noShutdownHook := func(ctx context.Context, _ *ShutdownHook) error { + return fn(ctx) + } + return RunWithGracefulShutdownHook(noShutdownHook, options...) +} + +// RunWithGracefulShutdownHook runs a service and provides a hook ShutdownHook where uer can call to register a +// cancel function that will be called before canceling the root context. +// Root context will be canceled if the registered cancel functions runs for time longer than +// gracefulShutdownMaxGracePeriod. +// Configuration of the service is loaded from the environment. +// +// Example usage code can be like: +// +// err := cloudrunner.RunWithGracefulShutdownHook(func(ctx context.Context, hook *cloudrunner.ShutdownHook) error { +// grpcServer := cloudrunner.NewGRPCServer(ctx) +// hook.HookCancelFunc(func() { +// grpcServer.Stop() +// // or clean up any other resources here +// }) +// return cloudrunner.ListenGRPC(ctx, grpcServer) +// }) +func RunWithGracefulShutdownHook( + fn func(ctx context.Context, shutdownHook *ShutdownHook) error, + options ...Option, +) (err error) { + ctx, cancel := context.WithCancel(context.Background()) defer cancel() usage := flag.Bool("help", false, "show help then exit") yamlServiceSpecificationFile := flag.String("config", "", "load environment from a YAML service specification") @@ -152,7 +194,61 @@ func Run(fn func(context.Context) error, options ...Option) (err error) { ) } }() - return fn(ctx) + + hook := &ShutdownHook{ + rootCtxCancelFunc: cancel, + } + go hook.trapShutdownSignal(ctx, logger) + return fn(ctx, hook) +} + +// ShutdownHook is used for CloudRunner to gracefully shutdown. It makes sure shutdownFunc is called before +// rootCtxCancelFunc is called. +type ShutdownHook struct { + rootCtxCancelFunc func() + shutdownFunc func() +} + +// HookCancelFunc can wire up a cancel function which will be called when SIGTERM is received, and before the root +// context is canceled. +func (s *ShutdownHook) HookCancelFunc(cancel func()) { + s.shutdownFunc = cancel +} + +// trapShutdownSignal blocks and waits for shutdown signal, if received, call s.shutdownFunc() then shutdown. +// +//nolint:lll +func (s *ShutdownHook) trapShutdownSignal(ctx context.Context, logger *zap.Logger) { + logger.Info("watching for termination signals") + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM) + signal.Notify(sigChan, syscall.SIGINT) + + // block and wait for a shutdown signal + sig := <-sigChan + logger.Info("got signal:", zap.String("signal", sig.String())) + if s.shutdownFunc == nil { + logger.Info( + "ShutdownHook is not used. Canceling root context directly. Call RunWithGracefulShutdownHook(...) to enable graceful shutdown if preferred.", + ) + s.rootCtxCancelFunc() + return + } + + // initiate graceful shutdown by calling s.shutdownFunc() + logger.Info("graceful shutdown has begun") + gracefulPeriodCtx, gracefulPeriodCtxCancel := context.WithTimeout(ctx, gracefulShutdownMaxGracePeriod) + go func() { + s.shutdownFunc() + logger.Info("ShutdownHook.shutdownFunc() has finished, meaning we will shutdown cleanly") + gracefulPeriodCtxCancel() + }() + + // block and wait until s.shutdownFunc() finish or gracefulPeriodCtx timeout. + <-gracefulPeriodCtx.Done() + logger.Info("exiting by canceling root context due to shutdown signal") + + s.rootCtxCancelFunc() } type runContext struct { diff --git a/run_test.go b/run_test.go index 7c34076d..44e64ecc 100644 --- a/run_test.go +++ b/run_test.go @@ -2,28 +2,98 @@ package cloudrunner_test import ( "context" + "flag" "log" + "os" + "sync" + "syscall" + "testing" + "time" "go.einride.tech/cloudrunner" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" + "gotest.tools/v3/assert" ) -func ExampleRun_helloWorld() { - if err := cloudrunner.Run(func(ctx context.Context) error { +func Test_Run_helloWorld(t *testing.T) { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + err := cloudrunner.Run(func(ctx context.Context) error { cloudrunner.Logger(ctx).Info("hello world") return nil - }); err != nil { - log.Fatal(err) - } + }) + + assert.NilError(t, err) } -func ExampleRun_gRPCServer() { - if err := cloudrunner.Run(func(ctx context.Context) error { +func Test_Run_gRPCServer(t *testing.T) { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + err := cloudrunner.Run(func(ctx context.Context) error { grpcServer := cloudrunner.NewGRPCServer(ctx) healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) + + // For shutdown gRPC server otherwise we get blocked on ListenGRPC + go func() { + time.Sleep(time.Second) + _ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + }() return cloudrunner.ListenGRPC(ctx, grpcServer) + }) + + assert.NilError(t, err) +} + +func Test_RunWithGracefulShutdownHook_gRPCServer(t *testing.T) { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + err := cloudrunner.RunWithGracefulShutdownHook(func(ctx context.Context, hook *cloudrunner.ShutdownHook) error { + grpcServer := cloudrunner.NewGRPCServer(ctx) + healthServer := health.NewServer() + grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) + hook.HookCancelFunc(func() { + grpcServer.Stop() + healthServer.Shutdown() + }) + + // For shutdown gRPC server otherwise we get blocked on ListenGRPC + go func() { + time.Sleep(time.Second) + _ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + }() + return cloudrunner.ListenGRPC(ctx, grpcServer) + }) + + assert.NilError(t, err) +} + +func Test_RunWithGracefulShutdownHook_helloWorld_ctx_cancel_should_before_clean_up_function_call(t *testing.T) { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + if err := cloudrunner.RunWithGracefulShutdownHook(func(ctx context.Context, hook *cloudrunner.ShutdownHook) error { + wg := sync.WaitGroup{} + wg.Add(1) + cleanup := func() { + var isRootContextDone bool + select { + case <-ctx.Done(): + isRootContextDone = true + default: + isRootContextDone = false + } + assert.Equal(t, isRootContextDone, false) + wg.Done() + } + + hook.HookCancelFunc(cleanup) + cloudrunner.Logger(ctx).Info("hello world") + + go func() { + // Simulating seeding a SIGTERM call. + time.Sleep(time.Second) + _ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + }() + + wg.Wait() + return nil }); err != nil { log.Fatal(err) }