diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 645f2e2..7adcc09 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -10,11 +10,7 @@ import ( probing "github.com/prometheus-community/pro-bing" ) -var usage = ` -Usage: - - ping [-c count] [-i interval] [-t timeout] [--privileged] host - +var examples = ` Examples: # ping google continuously @@ -34,6 +30,9 @@ Examples: # Send ICMP messages with a 100-byte payload ping -s 100 1.1.1.1 + + # Send ICMP messages with DSCP CS4 and ECN bits set to 0 + ping -Q 128 8.8.8.8 ` func main() { @@ -42,9 +41,13 @@ func main() { count := flag.Int("c", -1, "") size := flag.Int("s", 24, "") ttl := flag.Int("l", 64, "TTL") + tclass := flag.Int("Q", 192, "Set Quality of Service -related bits in ICMP datagrams (DSCP + ECN bits). Only decimal number supported") privileged := flag.Bool("privileged", false, "") flag.Usage = func() { - fmt.Print(usage) + out := flag.CommandLine.Output() + fmt.Fprintf(out, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + fmt.Fprint(out, examples) } flag.Parse() @@ -91,6 +94,7 @@ func main() { pinger.Timeout = *timeout pinger.TTL = *ttl pinger.SetPrivileged(*privileged) + pinger.SetTrafficClass(uint8(*tclass)) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) err = pinger.Run() diff --git a/packetconn.go b/packetconn.go index c4ca820..af1a737 100644 --- a/packetconn.go +++ b/packetconn.go @@ -21,6 +21,7 @@ type packetConn interface { SetMark(m uint) error SetDoNotFragment() error SetBroadcastFlag() error + SetTrafficClass(uint8) error } type icmpConn struct { @@ -67,6 +68,10 @@ func (c *icmpv4Conn) SetFlagTTL() error { return err } +func (c *icmpv4Conn) SetTrafficClass(tclass uint8) error { + return c.c.IPv4PacketConn().SetTOS(int(tclass)) +} + func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { ttl := -1 n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b) @@ -92,6 +97,10 @@ func (c *icmpV6Conn) SetFlagTTL() error { return err } +func (c *icmpV6Conn) SetTrafficClass(tclass uint8) error { + return c.c.IPv6PacketConn().SetTrafficClass(int(tclass)) +} + func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { ttl := -1 n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b) diff --git a/ping.go b/ping.go index 9175192..40cf0c4 100644 --- a/ping.go +++ b/ping.go @@ -116,6 +116,7 @@ func New(addr string) *Pinger { protocol: "udp", awaitingSequences: firstSequence, TTL: 64, + tclass: 192, // CS6 (network control) logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } } @@ -236,6 +237,9 @@ type Pinger struct { logger Logger TTL int + + // tclass defines the traffic class (ToS for IPv4) set on outgoing icmp packets + tclass uint8 } type packet struct { @@ -485,6 +489,18 @@ func (p *Pinger) SetDoNotFragment(df bool) { p.df = df } +// SetTrafficClass sets the traffic class (type-of-service field for IPv4) field +// value for future outgoing packets. +func (p *Pinger) SetTrafficClass(tc uint8) { + p.tclass = tc +} + +// TrafficClass returns the traffic class field (type-of-service field for IPv4) +// value for outgoing packets. +func (p *Pinger) TrafficClass() uint8 { + return p.tclass +} + // Run runs the pinger. This is a blocking function that will exit when it's // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. @@ -524,6 +540,12 @@ func (p *Pinger) RunWithContext(ctx context.Context) error { } } + if p.tclass != 0 { + if err := conn.SetTrafficClass(p.tclass); err != nil { + return fmt.Errorf("error setting traffic class: %v", err) + } + } + conn.SetTTL(p.TTL) return p.run(ctx, conn) } diff --git a/ping_test.go b/ping_test.go index bbbfe9b..15e870a 100644 --- a/ping_test.go +++ b/ping_test.go @@ -241,6 +241,7 @@ func TestNewPingerValid(t *testing.T) { AssertNotEqualStrings(t, "www.google.com", p.IPAddr().String()) AssertTrue(t, isIPv4(p.IPAddr().IP)) AssertFalse(t, p.Privileged()) + AssertEquals(t, 192, p.tclass) // Test that SetPrivileged works p.SetPrivileged(true) AssertTrue(t, p.Privileged()) @@ -252,6 +253,9 @@ func TestNewPingerValid(t *testing.T) { err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) AssertFalse(t, isIPv4(p.IPAddr().IP)) + // Test setting traffic class + p.SetTrafficClass(0) + AssertEquals(t, 0, p.tclass) p = New("localhost") err = p.Resolve() @@ -520,6 +524,14 @@ func AssertEqualStrings(t *testing.T, expected, actual string) { } } +func AssertEquals[T comparable](t *testing.T, expected, actual T) { + t.Helper() + if expected != actual { + t.Errorf("Expected %v, got %v, Stack:\n%s", + expected, actual, string(debug.Stack())) + } +} + func AssertNotEqualStrings(t *testing.T, expected, actual string) { t.Helper() if expected == actual { @@ -644,6 +656,7 @@ func (c testPacketConn) SetTTL(t int) {} func (c testPacketConn) SetMark(m uint) error { return nil } func (c testPacketConn) SetDoNotFragment() error { return nil } func (c testPacketConn) SetBroadcastFlag() error { return nil } +func (c testPacketConn) SetTrafficClass(uint8) error { return nil } func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, testAddr, nil