Skip to content

Commit

Permalink
Merge pull request #29 from TheRushingWookie/add_context
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Kochie <[email protected]>
  • Loading branch information
SuperQ authored Apr 11, 2023
2 parents 8ae78da + 8b025e3 commit a17ba03
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
23 changes: 20 additions & 3 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ package probing

import (
"bytes"
"context"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -401,6 +402,13 @@ func (p *Pinger) ID() int {
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) Run() error {
return p.RunWithContext(context.Background())
}

// RunWithContext runs the pinger with a context. This is a blocking function that will exit when it's
// done or if the context is canceled. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) RunWithContext(ctx context.Context) error {
var conn packetConn
var err error
if p.Size < timeSliceLength+trackerLength {
Expand All @@ -418,10 +426,10 @@ func (p *Pinger) Run() error {
defer conn.Close()

conn.SetTTL(p.TTL)
return p.run(conn)
return p.run(ctx, conn)
}

func (p *Pinger) run(conn packetConn) error {
func (p *Pinger) run(ctx context.Context, conn packetConn) error {
if err := conn.SetFlagTTL(); err != nil {
return err
}
Expand All @@ -434,7 +442,16 @@ func (p *Pinger) run(conn packetConn) error {
handler()
}

var g errgroup.Group
g, ctx := errgroup.WithContext(ctx)

g.Go(func() error {
select {
case <-ctx.Done():
p.Stop()
case <-p.done:
}
return nil
})

g.Go(func() error {
defer p.Stop()
Expand Down
59 changes: 56 additions & 3 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package probing

import (
"bytes"
"context"
"errors"
"net"
"runtime/debug"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -667,7 +669,7 @@ func TestRunBadWrite(t *testing.T) {

var conn testPacketConnBadWrite

err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err != nil)

stats := pinger.Statistics()
Expand Down Expand Up @@ -696,7 +698,7 @@ func TestRunBadRead(t *testing.T) {

var conn testPacketConnBadRead

err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err != nil)

stats := pinger.Statistics()
Expand All @@ -710,12 +712,15 @@ func TestRunBadRead(t *testing.T) {

type testPacketConnOK struct {
testPacketConn
m sync.Mutex
writeDone int32
buf []byte
dst net.Addr
}

func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) {
c.m.Lock()
defer c.m.Unlock()
c.buf = make([]byte, len(b))
c.dst = dst
n := copy(c.buf, b)
Expand All @@ -724,6 +729,8 @@ func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) {
}

func (c *testPacketConnOK) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
c.m.Lock()
defer c.m.Unlock()
if atomic.LoadInt32(&c.writeDone) == 0 {
return 0, 0, nil, nil
}
Expand All @@ -749,7 +756,7 @@ func TestRunOK(t *testing.T) {

conn := new(testPacketConnOK)

err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err == nil)

stats := pinger.Statistics()
Expand All @@ -762,3 +769,49 @@ func TestRunOK(t *testing.T) {
AssertTrue(t, stats.MinRtt >= 10*time.Millisecond)
AssertTrue(t, stats.MinRtt <= 12*time.Millisecond)
}

func TestRunWithTimeoutContext(t *testing.T) {
pinger := New("127.0.0.1")

err := pinger.Resolve()
AssertNoError(t, err)

conn := new(testPacketConnOK)

start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = pinger.run(ctx, conn)
AssertTrue(t, err == nil)
elapsedTime := time.Since(start)
AssertTrue(t, elapsedTime < 10*time.Second)

stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsSent > 0)
AssertTrue(t, stats.PacketsRecv > 0)
}

func TestRunWithBackgroundContext(t *testing.T) {
pinger := New("127.0.0.1")
pinger.Count = 10
pinger.Interval = 100 * time.Millisecond

err := pinger.Resolve()
AssertNoError(t, err)

conn := new(testPacketConnOK)

err = pinger.run(context.Background(), conn)
AssertTrue(t, err == nil)

stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsRecv == 10)
}

0 comments on commit a17ba03

Please sign in to comment.