Skip to content

Commit

Permalink
Add single start per running monitor plus tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devzbysiu committed Dec 12, 2024
1 parent a351c1d commit 6efd70a
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 4 deletions.
37 changes: 35 additions & 2 deletions meshnet/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package meshnet

import (
"context"
"errors"
"log"
"sync"
"sync/atomic"

"github.com/NordSecurity/nordvpn-linux/internal"
"github.com/vishvananda/netlink"
Expand Down Expand Up @@ -32,8 +35,9 @@ type EventHandler interface {
// NetlinkProcessMonitor monitors EXEC and EXIT events of processes and calls
// [EventHandler.OnProcessStarted] and [EventHandler.OnProcessStopped] accordingly.
type NetlinkProcessMonitor struct {
handler EventHandler
setup SetupFn
handler EventHandler
setup SetupFn
isRunning atomic.Bool
}

func NewProcMonitor(handler EventHandler, setup SetupFn) NetlinkProcessMonitor {
Expand All @@ -48,17 +52,46 @@ func NewProcMonitor(handler EventHandler, setup SetupFn) NetlinkProcessMonitor {
// It recreates the source of the events by calling [SetupFn]
// every time [NetlinkProcessMonitor.Start] is called.
func (pm *NetlinkProcessMonitor) Start(ctx context.Context) error {
if !pm.isRunning.CompareAndSwap(false, true) {
return errors.New("monitoring already started for this instance")
}

var monitoringStarted sync.WaitGroup
monitoringStarted.Add(1)
err := pm.start(ctx, &monitoringStarted)
if err != nil {
pm.isRunning.CompareAndSwap(true, false)
return err
}
monitoringStarted.Wait()
return nil
}

func (pm *NetlinkProcessMonitor) start(ctx context.Context, monitoringStarted *sync.WaitGroup) error {
channels, err := pm.setup()
if err != nil {
return err
}

go func() {
monitoringStarted.Done()
// check if the work was cancelled before the loop even started
// and mark the monitor as not running
select {
case <-ctx.Done():
pm.isRunning.CompareAndSwap(true, false)
channels.DoneCh <- struct{}{}
return
default:
// continue to main loop
}

for {
select {
case ev := <-channels.EventCh:
pm.handleProcessEvent(&ev)
case <-ctx.Done():
pm.isRunning.CompareAndSwap(true, false)
channels.DoneCh <- struct{}{}
return
case err := <-channels.ErrCh:
Expand Down
94 changes: 92 additions & 2 deletions meshnet/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,104 @@ func TestNetlinkProcessMonitor_Start(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
monitor := NewProcMonitor(eventHandlerDummy{}, tt.setupFn)
ctx, _ := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := monitor.Start(ctx)

assert.Equal(t, err != nil, tt.isError)
})
}
}

func TestNetlinkProcessMonitor_Start_AllowedOnlyOncePerRunningMonitor(t *testing.T) {
category.Set(t, category.Unit)
channels, setup := openChannelsMonitorSetup()
monitor := NewProcMonitor(eventHandlerDummy{}, setup)
ctx, cancel := context.WithCancel(context.Background())

// starting first time - fine
err := monitor.Start(ctx)
assert.Nil(t, err)

// not allowed until monitor is running
err = monitor.Start(ctx)
assert.NotNil(t, err)

cancel()

select {
case <-channels.DoneCh:
// cancellation done
case <-time.After(time.Second):
t.Fatal("timeout waiting for monitor to acknowledge cancellation")
}

// allowed again after cancellation
err = monitor.Start(ctx)
assert.Nil(t, err)
}

func TestNetlinkProcessMonitor_Start_RevertsOnSetupFailure(t *testing.T) {
category.Set(t, category.Unit)
monitor := NewProcMonitor(eventHandlerDummy{}, failingSetup)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err := monitor.Start(ctx)

assert.NotNil(t, err)
}

func TestNetlinkProcessMonitor_Start_RevertsOnImmediateCancel(t *testing.T) {
category.Set(t, category.Unit)
monitor := NewProcMonitor(eventHandlerDummy{}, workingSetup)
ctx, cancel := context.WithCancel(context.Background())
err := monitor.Start(ctx)
assert.Nil(t, err)

// immediately cancel
cancel()

ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

timeout := time.NewTimer(1 * time.Second)
defer timeout.Stop()

// busy wait for max of `timeout` to check that `isRunning` was reverted
CheckLoop:
for {
select {
case <-ticker.C:
if !monitor.isRunning.Load() {
break CheckLoop // exit outer loop, not select statement
}
case <-timeout.C:
t.Fatal("isRunning did not revert to false after context cancellation")
}
}
}

func TestNetlinkProcessMonitor_StartStop(t *testing.T) {
category.Set(t, category.Unit)
channels, setupFn := openChannelsMonitorSetup()
monitor := NewProcMonitor(eventHandlerDummy{}, setupFn)
ctx, cancel := context.WithCancel(context.Background())

err := monitor.Start(ctx)
assert.Nil(t, err)
assert.True(t, monitor.isRunning.Load())

cancel()

select {
case <-channels.DoneCh:
assert.False(t, monitor.isRunning.Load())
case <-time.After(time.Second):
t.Fatal("timeout waiting for monitor to stop")
}
}

func TestNetlinkProcessMonitor_EventHandler(t *testing.T) {
category.Set(t, category.Unit)

Expand Down Expand Up @@ -206,7 +296,7 @@ func failingSetup() (MonitorChannels, error) {
func workingSetup() (MonitorChannels, error) {
return MonitorChannels{
EventCh: make(chan netlink.ProcEvent),
DoneCh: make(chan struct{}),
DoneCh: make(chan struct{}, 1),
ErrCh: make(chan error),
}, nil
}
Expand Down

0 comments on commit 6efd70a

Please sign in to comment.