Skip to content

Commit

Permalink
windows: wait for control actions before returning
Browse files Browse the repository at this point in the history
Also buffer the OS signal so it's not potentially lost during Run.
  • Loading branch information
djdv committed May 12, 2021
1 parent ef35c56 commit ea2447d
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 71 deletions.
16 changes: 11 additions & 5 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service_test

import (
"fmt"
"os"
"testing"
"time"
Expand All @@ -22,22 +23,27 @@ func TestRunInterrupt(t *testing.T) {
t.Fatalf("New err: %s", err)
}

retChan := make(chan error)
go func() {
if err = s.Run(); err != nil {
retChan <- fmt.Errorf("Run() err: %w", err)
}
}()
go func() {
<-time.After(1 * time.Second)
interruptProcess(t)
}()

go func() {
for i := 0; i < 25 && p.numStopped == 0; i++ {
<-time.After(200 * time.Millisecond)
}
if p.numStopped == 0 {
t.Fatal("Run() hasn't been stopped")
retChan <- fmt.Errorf("Run() hasn't been stopped")
}
retChan <- nil
}()

if err = s.Run(); err != nil {
t.Fatalf("Run() err: %s", err)
if err = <-retChan; err != nil {
t.Fatal(err)
}
}

Expand Down
254 changes: 188 additions & 66 deletions service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,48 +160,63 @@ func (ws *windowsService) getError() error {
return ws.stopStartErr
}

func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) {
const exitFailure = 1

// Signal that we're starting.
changes <- svc.Status{State: svc.StartPending}

if err := ws.i.Start(ws); err != nil {
ws.setError(err)
return true, 1
// Perform the actual start.
if initErr := ws.i.Start(ws); initErr != nil {
ws.setError(initErr)
exitCode = exitFailure
return
}

// Signal that we're ready.
changes <- svc.Status{
State: svc.Running,
Accepts: svc.AcceptStop | svc.AcceptShutdown,
}

changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
// Expect service change requests.
var (
exitFunc func(s Service) error
runErr error
)
loop:
for {
c := <-r
for c := range r {
switch c.Cmd {
case svc.Interrogate:
changes <- c.CurrentStatus
case svc.Stop:
changes <- svc.Status{State: svc.StopPending}
if err := ws.i.Stop(ws); err != nil {
ws.setError(err)
return true, 2
}
break loop
case svc.Shutdown:
changes <- svc.Status{State: svc.StopPending}
var err error
if wsShutdown, ok := ws.i.(Shutdowner); ok {
err = wsShutdown.Shutdown(ws)
} else {
err = ws.i.Stop(ws)
}
if err != nil {
ws.setError(err)
return true, 2
if shutdowner, ok := ws.i.(Shutdowner); ok {
exitFunc = shutdowner.Shutdown
break loop
}
fallthrough
case svc.Stop:
exitFunc = ws.i.Stop
break loop
default:
continue loop
runErr = fmt.Errorf("unexpected control request: %v", c.Cmd)
exitCode = exitFailure
ws.setError(runErr)
break loop
}
}

// We were requested to stop.
changes <- svc.Status{State: svc.StopPending}
if exitErr := exitFunc(ws); exitErr != nil {
exitCode = exitFailure
if runErr != nil {
exitErr = fmt.Errorf("%s - %w", runErr, exitErr)
}
ws.setError(exitErr)
}

return false, 0
return
}

func (ws *windowsService) Install() error {
Expand Down Expand Up @@ -249,19 +264,55 @@ func (ws *windowsService) Uninstall() error {
return err
}
defer m.Disconnect()

// MSDN:
// "The DeleteService function marks a service for deletion
// from the service control manager database.
// The database entry is not removed until all open handles
// to the service have been closed by calls to the CloseServiceHandle function,
// and the service is not running."
//
// Since we want to try and wait for the delete to actually happen.
// We close this handle manually when appropriate.
s, err := m.OpenService(ws.Name)
if err != nil {
return fmt.Errorf("service %s is not installed", ws.Name)
}
defer s.Close()
err = s.Delete()
if err != nil {

if err = s.Delete(); err != nil {
s.Close()
return err
}
err = eventlog.Remove(ws.Name)
if err != nil {
if err = eventlog.Remove(ws.Name); err != nil {
s.Close()
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}

// Service is now marked for deletion by the system.
// Release our handle to it.
if err := s.Close(); err != nil {
return err
}

// Try to get the service handle back.
// If we get an error from the manager,
// we know the service has been deleted.
// Otherwise, we'll block and keep checking
// until the something returns an error, or we give up.
// Since the service is already marked for deletion,
// we don't consider the unblocking condition to be an error.
// But the service will still exist in the service manager's scope.
// And the caller of Uninstall will be on their own from there.
for attempts := 10; attempts != 0; attempts-- {
s, err := m.OpenService(ws.Name)
if err != nil {
break // expected
}
if err := s.Close(); err != nil {
return err
}
time.Sleep(100 * time.Millisecond)
}
return nil
}

Expand All @@ -287,7 +338,7 @@ func (ws *windowsService) Run() error {
return err
}

sigChan := make(chan os.Signal)
sigChan := make(chan os.Signal, 1)

signal.Notify(sigChan, os.Interrupt)

Expand Down Expand Up @@ -349,26 +400,20 @@ func (ws *windowsService) Start() error {
return err
}
defer s.Close()
return s.Start()
}

func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
if err != nil {
if err = maybeWaitForPending(s); err != nil {
return err
}
defer m.Disconnect()

s, err := m.OpenService(ws.Name)
if err != nil {
return err
initErr := s.Start()
if initErr != nil {
return initErr
}
defer s.Close()

return ws.stopWait(s)
return maybeWaitForPending(s)
}

func (ws *windowsService) Restart() error {
func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
if err != nil {
return err
Expand All @@ -381,42 +426,119 @@ func (ws *windowsService) Restart() error {
}
defer s.Close()

err = ws.stopWait(s)
if err != nil {
if err = maybeWaitForPending(s); err != nil {
return err
}

if _, err = s.Control(svc.Stop); err != nil {
return err
}

return s.Start()
return maybeWaitForPending(s)
}

func (ws *windowsService) stopWait(s *mgr.Service) error {
// First stop the service. Then wait for the service to
// actually stop before starting it.
status, err := s.Control(svc.Stop)
if err != nil {
func (ws *windowsService) Restart() error {
if err := ws.Stop(); err != nil {
return err
}
return ws.Start()
}

// statusInterval retreives a (bounded) duration from the status,
// or provides a default.
func statusInterval(status svc.Status) time.Duration {
// MSDN:
// "Do not wait longer than the wait hint. A good interval is
// one-tenth of the wait hint but not less than 1 second
// and not more than 10 seconds."
const (
lower = time.Second
upper = time.Second * 10
)

waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10
if waitDuration < lower {
waitDuration = lower
} else if waitDuration > upper {
waitDuration = upper
}
return waitDuration
}

// waitForStateChange polls the service until its state matches the desiredState,
// and error is encountered, or we timeout.
func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error {
const defaultAttempts = 10
var (
initialInterval = statusInterval(currentStatus)
queryTicker = time.NewTicker(initialInterval)
queryTimer *time.Timer
)
// If the service is providing hints,
// use them, otherwise use a default timeout.
if currentStatus.CheckPoint != 0 {
queryTimer = time.NewTimer(initialInterval)
} else {
queryTimer = time.NewTimer(initialInterval * defaultAttempts)
}
defer func() {
queryTicker.Stop()
queryTimer.Stop()
}()

var (
currentState = currentStatus.State
lastCheckpoint uint32
)
for currentState != desiredState {
select {
case <-queryTicker.C:
currentStatus, queryErr := s.Query()
if queryErr != nil {
return queryErr
}

timeDuration := time.Millisecond * 50

timeout := time.After(getStopTimeout() + (timeDuration * 2))
tick := time.NewTicker(timeDuration)
defer tick.Stop()
currentState = currentStatus.State
if currentState == desiredState {
return nil
}

for status.State != svc.Stopped {
select {
case <-tick.C:
status, err = s.Query()
if err != nil {
return err
if currentStatus.CheckPoint > lastCheckpoint {
// Service progressed,
// give it more time to complete.
if !queryTimer.Stop() {
<-queryTimer.C
}
queryTimer.Reset(statusInterval(currentStatus))
}
case <-timeout:
break
lastCheckpoint = currentStatus.CheckPoint
case <-queryTimer.C:
return fmt.Errorf("service did not enter desired state (%v) before we timed out",
desiredState)
}
}
return nil
}

func maybeWaitForPending(s *mgr.Service) error {
status, err := s.Query()
if err != nil {
return err
}

var wantState svc.State
switch status.State {
case svc.StartPending:
wantState = svc.Running
case svc.StopPending:
wantState = svc.Stopped
default:
return nil
}

return waitForStateChange(s, status, wantState)
}

// getStopTimeout fetches the time before windows will kill the service.
func getStopTimeout() time.Duration {
// For default and paths see https://support.microsoft.com/en-us/kb/146092
Expand Down

0 comments on commit ea2447d

Please sign in to comment.