Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for setting traffic class on outgoing packets #120

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ import (
probing "github.com/prometheus-community/pro-bing"
)

var usage = `
Usage:

ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host

var examples = `
Examples:

# ping google continuously
Expand All @@ -37,6 +33,9 @@ Examples:

# Send ICMP messages with a 100-byte payload
ping -s 100 1.1.1.1

# Send ICMP messages with DSCP CS4 and ECN bits set to 0
ping -Q 128 8.8.8.8
`

func main() {
Expand All @@ -46,9 +45,13 @@ func main() {
size := flag.Int("s", 24, "")
ttl := flag.Int("l", 64, "TTL")
iface := flag.String("I", "", "interface name")
tclass := flag.Int("Q", 192, "Set Quality of Service related bits in ICMP datagrams (DSCP + ECN bits). Only decimal number supported")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be better to make this a String() so we can parse both decimal and hex values like iputils does.

privileged := flag.Bool("privileged", false, "")
flag.Usage = func() {
fmt.Print(usage)
out := flag.CommandLine.Output()
fmt.Fprintf(out, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprint(out, examples)
}
flag.Parse()

Expand Down Expand Up @@ -96,6 +99,7 @@ func main() {
pinger.TTL = *ttl
pinger.InterfaceName = *iface
pinger.SetPrivileged(*privileged)
pinger.SetTrafficClass(uint8(*tclass))

fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr())
err = pinger.Run()
Expand Down
9 changes: 9 additions & 0 deletions packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type packetConn interface {
SetDoNotFragment() error
SetBroadcastFlag() error
SetIfIndex(ifIndex int)
SetTrafficClass(uint8) error
}

type icmpConn struct {
Expand Down Expand Up @@ -58,6 +59,10 @@ func (c *icmpv4Conn) SetFlagTTL() error {
return err
}

func (c *icmpv4Conn) SetTrafficClass(tclass uint8) error {
return c.c.IPv4PacketConn().SetTOS(int(tclass))
}

func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
ttl := -1
n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b)
Expand Down Expand Up @@ -99,6 +104,10 @@ func (c *icmpV6Conn) SetFlagTTL() error {
return err
}

func (c *icmpV6Conn) SetTrafficClass(tclass uint8) error {
return c.c.IPv6PacketConn().SetTrafficClass(int(tclass))
}

func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
ttl := -1
n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b)
Expand Down
22 changes: 22 additions & 0 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func New(addr string) *Pinger {
protocol: "udp",
awaitingSequences: firstSequence,
TTL: 64,
tclass: 192, // CS6 (network control)
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}
Expand Down Expand Up @@ -239,6 +240,9 @@ type Pinger struct {
logger Logger

TTL int

// tclass defines the traffic class (ToS for IPv4) set on outgoing icmp packets
tclass uint8
}

type packet struct {
Expand Down Expand Up @@ -488,6 +492,18 @@ func (p *Pinger) SetDoNotFragment(df bool) {
p.df = df
}

// SetTrafficClass sets the traffic class (type-of-service field for IPv4) field
// value for future outgoing packets.
func (p *Pinger) SetTrafficClass(tc uint8) {
p.tclass = tc
}

// TrafficClass returns the traffic class field (type-of-service field for IPv4)
// value for outgoing packets.
func (p *Pinger) TrafficClass() uint8 {
return p.tclass
}

// Run runs the pinger. This is a blocking function that will exit when it's
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
Expand Down Expand Up @@ -527,6 +543,12 @@ func (p *Pinger) RunWithContext(ctx context.Context) error {
}
}

if p.tclass != 0 {
if err := conn.SetTrafficClass(p.tclass); err != nil {
return fmt.Errorf("error setting traffic class: %v", err)
}
}

conn.SetTTL(p.TTL)
if p.InterfaceName != "" {
iface, err := net.InterfaceByName(p.InterfaceName)
Expand Down
14 changes: 14 additions & 0 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ func TestNewPingerValid(t *testing.T) {
AssertNotEqualStrings(t, "www.google.com", p.IPAddr().String())
AssertTrue(t, isIPv4(p.IPAddr().IP))
AssertFalse(t, p.Privileged())
AssertEquals(t, 192, p.tclass)
// Test that SetPrivileged works
p.SetPrivileged(true)
AssertTrue(t, p.Privileged())
Expand All @@ -253,6 +254,9 @@ func TestNewPingerValid(t *testing.T) {
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertFalse(t, isIPv4(p.IPAddr().IP))
// Test setting traffic class
p.SetTrafficClass(0)
AssertEquals(t, 0, p.tclass)

p = New("localhost")
err = p.Resolve()
Expand Down Expand Up @@ -541,6 +545,14 @@ func AssertEqualStrings(t *testing.T, expected, actual string) {
}
}

func AssertEquals[T comparable](t *testing.T, expected, actual T) {
t.Helper()
if expected != actual {
t.Errorf("Expected %v, got %v, Stack:\n%s",
expected, actual, string(debug.Stack()))
}
}

func AssertNotEqualStrings(t *testing.T, expected, actual string) {
t.Helper()
if expected == actual {
Expand Down Expand Up @@ -666,6 +678,8 @@ func (c testPacketConn) SetMark(m uint) error { return nil }
func (c testPacketConn) SetDoNotFragment() error { return nil }
func (c testPacketConn) SetBroadcastFlag() error { return nil }
func (c testPacketConn) SetIfIndex(ifIndex int) {}
func (c testPacketConn) SetTrafficClass(uint8) error { return nil }

func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, testAddr, nil
}
Expand Down