From ea2447d061d2dc77e06915c8c02545685f6b55dc Mon Sep 17 00:00:00 2001 From: Dominic Della Valle Date: Sat, 24 Apr 2021 05:33:34 -0400 Subject: [PATCH] windows: wait for control actions before returning Also buffer the OS signal so it's not potentially lost during Run. --- service_test.go | 16 ++- service_windows.go | 254 +++++++++++++++++++++++++++++++++------------ 2 files changed, 199 insertions(+), 71 deletions(-) diff --git a/service_test.go b/service_test.go index 886b0cfb..e60367a6 100644 --- a/service_test.go +++ b/service_test.go @@ -5,6 +5,7 @@ package service_test import ( + "fmt" "os" "testing" "time" @@ -22,22 +23,27 @@ func TestRunInterrupt(t *testing.T) { t.Fatalf("New err: %s", err) } + retChan := make(chan error) + go func() { + if err = s.Run(); err != nil { + retChan <- fmt.Errorf("Run() err: %w", err) + } + }() go func() { <-time.After(1 * time.Second) interruptProcess(t) - }() - go func() { for i := 0; i < 25 && p.numStopped == 0; i++ { <-time.After(200 * time.Millisecond) } if p.numStopped == 0 { - t.Fatal("Run() hasn't been stopped") + retChan <- fmt.Errorf("Run() hasn't been stopped") } + retChan <- nil }() - if err = s.Run(); err != nil { - t.Fatalf("Run() err: %s", err) + if err = <-retChan; err != nil { + t.Fatal(err) } } diff --git a/service_windows.go b/service_windows.go index c6ee3a6c..895eba70 100644 --- a/service_windows.go +++ b/service_windows.go @@ -160,48 +160,63 @@ func (ws *windowsService) getError() error { return ws.stopStartErr } -func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { - const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown +func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) { + const exitFailure = 1 + + // Signal that we're starting. 
changes <- svc.Status{State: svc.StartPending} - if err := ws.i.Start(ws); err != nil { - ws.setError(err) - return true, 1 + // Perform the actual start. + if initErr := ws.i.Start(ws); initErr != nil { + ws.setError(initErr) + exitCode = exitFailure + return + } + + // Signal that we're ready. + changes <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown, } - changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + // Expect service change requests. + var ( + exitFunc func(s Service) error + runErr error + ) loop: - for { - c := <-r + for c := range r { switch c.Cmd { case svc.Interrogate: changes <- c.CurrentStatus - case svc.Stop: - changes <- svc.Status{State: svc.StopPending} - if err := ws.i.Stop(ws); err != nil { - ws.setError(err) - return true, 2 - } - break loop case svc.Shutdown: - changes <- svc.Status{State: svc.StopPending} - var err error - if wsShutdown, ok := ws.i.(Shutdowner); ok { - err = wsShutdown.Shutdown(ws) - } else { - err = ws.i.Stop(ws) - } - if err != nil { - ws.setError(err) - return true, 2 + if shutdowner, ok := ws.i.(Shutdowner); ok { + exitFunc = shutdowner.Shutdown + break loop } + fallthrough + case svc.Stop: + exitFunc = ws.i.Stop break loop default: - continue loop + runErr = fmt.Errorf("unexpected control request: %v", c.Cmd) + exitCode = exitFailure + ws.setError(runErr) + break loop + } + } + + // We were requested to stop. + changes <- svc.Status{State: svc.StopPending} + if exitErr := exitFunc(ws); exitErr != nil { + exitCode = exitFailure + if runErr != nil { + exitErr = fmt.Errorf("%s - %w", runErr, exitErr) } + ws.setError(exitErr) } - return false, 0 + return } func (ws *windowsService) Install() error { @@ -249,19 +264,55 @@ func (ws *windowsService) Uninstall() error { return err } defer m.Disconnect() + + // MSDN: + // "The DeleteService function marks a service for deletion + // from the service control manager database. 
+ // The database entry is not removed until all open handles + // to the service have been closed by calls to the CloseServiceHandle function, + // and the service is not running." + // + // Since we want to try and wait for the delete to actually happen, + // we close this handle manually when appropriate. s, err := m.OpenService(ws.Name) if err != nil { return fmt.Errorf("service %s is not installed", ws.Name) } - defer s.Close() - err = s.Delete() - if err != nil { + + if err = s.Delete(); err != nil { + s.Close() return err } - err = eventlog.Remove(ws.Name) - if err != nil { + if err = eventlog.Remove(ws.Name); err != nil { + s.Close() return fmt.Errorf("RemoveEventLogSource() failed: %s", err) } + + // Service is now marked for deletion by the system. + // Release our handle to it. + if err := s.Close(); err != nil { + return err + } + + // Try to get the service handle back. + // If we get an error from the manager, + // we know the service has been deleted. + // Otherwise, we'll block and keep checking + // until something returns an error, or we give up. + // Since the service is already marked for deletion, + // we don't consider the unblocking condition to be an error. + // But the service will still exist in the service manager's scope. + // And the caller of Uninstall will be on their own from there. 
+ for attempts := 10; attempts != 0; attempts-- { + s, err := m.OpenService(ws.Name) + if err != nil { + break // expected + } + if err := s.Close(); err != nil { + return err + } + time.Sleep(100 * time.Millisecond) + } return nil } @@ -287,7 +338,7 @@ func (ws *windowsService) Run() error { return err } - sigChan := make(chan os.Signal) + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt) @@ -349,26 +400,20 @@ func (ws *windowsService) Start() error { return err } defer s.Close() - return s.Start() -} -func (ws *windowsService) Stop() error { - m, err := mgr.Connect() - if err != nil { + if err = maybeWaitForPending(s); err != nil { return err } - defer m.Disconnect() - s, err := m.OpenService(ws.Name) - if err != nil { - return err + initErr := s.Start() + if initErr != nil { + return initErr } - defer s.Close() - return ws.stopWait(s) + return maybeWaitForPending(s) } -func (ws *windowsService) Restart() error { +func (ws *windowsService) Stop() error { m, err := mgr.Connect() if err != nil { return err @@ -381,42 +426,119 @@ func (ws *windowsService) Restart() error { } defer s.Close() - err = ws.stopWait(s) - if err != nil { + if err = maybeWaitForPending(s); err != nil { + return err + } + + if _, err = s.Control(svc.Stop); err != nil { return err } - return s.Start() + return maybeWaitForPending(s) } -func (ws *windowsService) stopWait(s *mgr.Service) error { - // First stop the service. Then wait for the service to - // actually stop before starting it. - status, err := s.Control(svc.Stop) - if err != nil { +func (ws *windowsService) Restart() error { + if err := ws.Stop(); err != nil { return err } + return ws.Start() +} + +// statusInterval retrieves a (bounded) duration from the status, +// or provides a default. +func statusInterval(status svc.Status) time.Duration { + // MSDN: + // "Do not wait longer than the wait hint. 
A good interval is + // one-tenth of the wait hint but not less than 1 second + // and not more than 10 seconds." + const ( + lower = time.Second + upper = time.Second * 10 + ) + + waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10 + if waitDuration < lower { + waitDuration = lower + } else if waitDuration > upper { + waitDuration = upper + } + return waitDuration +} + +// waitForStateChange polls the service until its state matches the desiredState, +// an error is encountered, or we time out. +func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error { + const defaultAttempts = 10 + var ( + initialInterval = statusInterval(currentStatus) + queryTicker = time.NewTicker(initialInterval) + queryTimer *time.Timer + ) + // If the service is providing hints, + // use them, otherwise use a default timeout. + if currentStatus.CheckPoint != 0 { + queryTimer = time.NewTimer(initialInterval) + } else { + queryTimer = time.NewTimer(initialInterval * defaultAttempts) + } + defer func() { + queryTicker.Stop() + queryTimer.Stop() + }() + + var ( + currentState = currentStatus.State + lastCheckpoint uint32 + ) + for currentState != desiredState { + select { + case <-queryTicker.C: + currentStatus, queryErr := s.Query() + if queryErr != nil { + return queryErr + } - timeDuration := time.Millisecond * 50 - - timeout := time.After(getStopTimeout() + (timeDuration * 2)) - tick := time.NewTicker(timeDuration) - defer tick.Stop() + currentState = currentStatus.State + if currentState == desiredState { + return nil + } - for status.State != svc.Stopped { - select { - case <-tick.C: - status, err = s.Query() - if err != nil { - return err + if currentStatus.CheckPoint > lastCheckpoint { + // Service progressed, + // give it more time to complete. 
+ if !queryTimer.Stop() { + <-queryTimer.C + } + queryTimer.Reset(statusInterval(currentStatus)) } - case <-timeout: - break + lastCheckpoint = currentStatus.CheckPoint + case <-queryTimer.C: + return fmt.Errorf("service did not enter desired state (%v) before we timed out", + desiredState) } } return nil } +func maybeWaitForPending(s *mgr.Service) error { + status, err := s.Query() + if err != nil { + return err + } + + var wantState svc.State + switch status.State { + case svc.StartPending: + wantState = svc.Running + case svc.StopPending: + wantState = svc.Stopped + default: + return nil + } + + return waitForStateChange(s, status, wantState) +} + // getStopTimeout fetches the time before windows will kill the service. func getStopTimeout() time.Duration { // For default and paths see https://support.microsoft.com/en-us/kb/146092