diff --git a/ping.go b/ping.go index ece069b..caa4bd9 100644 --- a/ping.go +++ b/ping.go @@ -53,6 +53,7 @@ package probing import ( "bytes" + "context" "errors" "fmt" "log" @@ -401,6 +402,13 @@ func (p *Pinger) ID() int { // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. func (p *Pinger) Run() error { + return p.RunWithContext(context.Background()) +} + +// RunWithContext runs the pinger with a context. This is a blocking function that will exit when it's +// done or if the context is canceled. If Count or Interval are not specified, it will run continuously until +// it is interrupted. +func (p *Pinger) RunWithContext(ctx context.Context) error { var conn packetConn var err error if p.Size < timeSliceLength+trackerLength { @@ -418,10 +426,10 @@ func (p *Pinger) Run() error { defer conn.Close() conn.SetTTL(p.TTL) - return p.run(conn) + return p.run(ctx, conn) } -func (p *Pinger) run(conn packetConn) error { +func (p *Pinger) run(ctx context.Context, conn packetConn) error { if err := conn.SetFlagTTL(); err != nil { return err } @@ -434,7 +442,16 @@ func (p *Pinger) run(conn packetConn) error { handler() } - var g errgroup.Group + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + select { + case <-ctx.Done(): + p.Stop() + case <-p.done: + } + return nil + }) g.Go(func() error { defer p.Stop() diff --git a/ping_test.go b/ping_test.go index 87fe448..c7baf7f 100644 --- a/ping_test.go +++ b/ping_test.go @@ -2,9 +2,11 @@ package probing import ( "bytes" + "context" "errors" "net" "runtime/debug" + "sync" "sync/atomic" "testing" "time" @@ -667,7 +669,7 @@ func TestRunBadWrite(t *testing.T) { var conn testPacketConnBadWrite - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err != nil) stats := pinger.Statistics() @@ -696,7 +698,7 @@ func TestRunBadRead(t *testing.T) { var conn testPacketConnBadRead - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err != nil) stats := pinger.Statistics() @@ -710,12 +712,15 @@ func TestRunBadRead(t *testing.T) { type testPacketConnOK struct { testPacketConn + m sync.Mutex writeDone int32 buf []byte dst net.Addr } func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) { + c.m.Lock() + defer c.m.Unlock() c.buf = make([]byte, len(b)) c.dst = dst n := copy(c.buf, b) @@ -724,6 +729,8 @@ func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) { } func (c *testPacketConnOK) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { + c.m.Lock() + defer c.m.Unlock() if atomic.LoadInt32(&c.writeDone) == 0 { return 0, 0, nil, nil } @@ -749,7 +756,7 @@ func TestRunOK(t *testing.T) { conn := new(testPacketConnOK) - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err == nil) stats := pinger.Statistics() @@ -762,3 +769,49 @@ func TestRunOK(t *testing.T) { AssertTrue(t, stats.MinRtt >= 10*time.Millisecond) AssertTrue(t, stats.MinRtt <= 12*time.Millisecond) } + +func TestRunWithTimeoutContext(t *testing.T) { + pinger := New("127.0.0.1") + + err := pinger.Resolve() + AssertNoError(t, err) + + conn := new(testPacketConnOK) + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err = pinger.run(ctx, conn) + AssertTrue(t, err == nil) + elapsedTime := time.Since(start) + AssertTrue(t, elapsedTime < 10*time.Second) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsSent > 0) + AssertTrue(t, stats.PacketsRecv > 0) +} + +func TestRunWithBackgroundContext(t *testing.T) { + pinger := New("127.0.0.1") + pinger.Count = 10 + pinger.Interval = 100 * time.Millisecond + + err := pinger.Resolve() + AssertNoError(t, err) + + conn := new(testPacketConnOK) + + err = pinger.run(context.Background(), conn) + AssertTrue(t, err == nil) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsRecv == 10) +}