diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index acb5b89..a9ba7a1 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -10,11 +10,7 @@ import ( probing "github.com/prometheus-community/pro-bing" ) -var usage = ` -Usage: - - ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host - +var examples = ` Examples: # ping google continuously @@ -37,6 +33,9 @@ Examples: # Send ICMP messages with a 100-byte payload ping -s 100 1.1.1.1 + + # Send ICMP messages with DSCP CS4 and ECN bits set to 0 + ping -Q 128 8.8.8.8 ` func main() { @@ -46,9 +45,13 @@ func main() { size := flag.Int("s", 24, "") ttl := flag.Int("l", 64, "TTL") iface := flag.String("I", "", "interface name") + tclass := flag.Int("Q", 192, "Set Quality of Service related bits in ICMP datagrams (DSCP + ECN bits). Only decimal number supported") privileged := flag.Bool("privileged", false, "") flag.Usage = func() { - fmt.Print(usage) + out := flag.CommandLine.Output() + fmt.Fprintf(out, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + fmt.Fprint(out, examples) } flag.Parse() @@ -96,6 +99,7 @@ func main() { pinger.TTL = *ttl pinger.InterfaceName = *iface pinger.SetPrivileged(*privileged) + pinger.SetTrafficClass(uint8(*tclass)) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) err = pinger.Run() diff --git a/packetconn.go b/packetconn.go index 8528d79..b923a19 100644 --- a/packetconn.go +++ b/packetconn.go @@ -22,6 +22,7 @@ type packetConn interface { SetDoNotFragment() error SetBroadcastFlag() error SetIfIndex(ifIndex int) + SetTrafficClass(uint8) error } type icmpConn struct { @@ -58,6 +59,10 @@ func (c *icmpv4Conn) SetFlagTTL() error { return err } +func (c *icmpv4Conn) SetTrafficClass(tclass uint8) error { + return c.c.IPv4PacketConn().SetTOS(int(tclass)) +} + func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { ttl := -1 n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b) @@ -99,6 +104,10 @@ func (c *icmpV6Conn) SetFlagTTL() error { return err } +func (c *icmpV6Conn) SetTrafficClass(tclass uint8) error { + return c.c.IPv6PacketConn().SetTrafficClass(int(tclass)) +} + func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { ttl := -1 n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b) diff --git a/ping.go b/ping.go index b6de420..fe5f366 100644 --- a/ping.go +++ b/ping.go @@ -116,6 +116,7 @@ func New(addr string) *Pinger { protocol: "udp", awaitingSequences: firstSequence, TTL: 64, + tclass: 192, // CS6 (network control) logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } } @@ -239,6 +240,9 @@ type Pinger struct { logger Logger TTL int + + // tclass defines the traffic class (ToS for IPv4) set on outgoing icmp packets + tclass uint8 } type packet struct { @@ -488,6 +492,18 @@ func (p *Pinger) SetDoNotFragment(df bool) { p.df = df } +// SetTrafficClass sets the traffic class (type-of-service field for IPv4) field +// value for future outgoing packets. +func (p *Pinger) SetTrafficClass(tc uint8) { + p.tclass = tc +} + +// TrafficClass returns the traffic class field (type-of-service field for IPv4) +// value for outgoing packets. +func (p *Pinger) TrafficClass() uint8 { + return p.tclass +} + // Run runs the pinger. This is a blocking function that will exit when it's // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. @@ -527,6 +543,12 @@ func (p *Pinger) RunWithContext(ctx context.Context) error { } } + if p.tclass != 0 { + if err := conn.SetTrafficClass(p.tclass); err != nil { + return fmt.Errorf("error setting traffic class: %v", err) + } + } + conn.SetTTL(p.TTL) if p.InterfaceName != "" { iface, err := net.InterfaceByName(p.InterfaceName) diff --git a/ping_test.go b/ping_test.go index d26bd82..2b2ba9b 100644 --- a/ping_test.go +++ b/ping_test.go @@ -242,6 +242,7 @@ func TestNewPingerValid(t *testing.T) { AssertNotEqualStrings(t, "www.google.com", p.IPAddr().String()) AssertTrue(t, isIPv4(p.IPAddr().IP)) AssertFalse(t, p.Privileged()) + AssertEquals(t, 192, p.tclass) // Test that SetPrivileged works p.SetPrivileged(true) AssertTrue(t, p.Privileged()) @@ -253,6 +254,9 @@ func TestNewPingerValid(t *testing.T) { err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) AssertFalse(t, isIPv4(p.IPAddr().IP)) + // Test setting traffic class + p.SetTrafficClass(0) + AssertEquals(t, 0, p.tclass) p = New("localhost") err = p.Resolve() @@ -541,6 +545,14 @@ func AssertEqualStrings(t *testing.T, expected, actual string) { } } +func AssertEquals[T comparable](t *testing.T, expected, actual T) { + t.Helper() + if expected != actual { + t.Errorf("Expected %v, got %v, Stack:\n%s", + expected, actual, string(debug.Stack())) + } +} + func AssertNotEqualStrings(t *testing.T, expected, actual string) { t.Helper() if expected == actual { @@ -666,6 +678,8 @@ func (c testPacketConn) SetMark(m uint) error { return nil } func (c testPacketConn) SetDoNotFragment() error { return nil } func (c testPacketConn) SetBroadcastFlag() error { return nil } func (c testPacketConn) SetIfIndex(ifIndex int) {} +func (c testPacketConn) SetTrafficClass(uint8) error { return nil } + func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, testAddr, nil }