Skip to content

Commit

Permalink
Add the Interface option to Pinger
Browse files Browse the repository at this point in the history
This allows using `pro-bing` to use VRFs interfaces and IPv6 link-local addresses.
Originally from @ilolicon in #32.

commit 16f9286
Author: Matthieu Pignolet <[email protected]>
Date:   Tue Oct 29 15:46:13 2024 +0400

    Remove the un-used function in te `packetConn` interface that did not get removed during the merging process

commit 88bb1f5
Author: ilolicon <[email protected]>
Date:   Wed Apr 12 16:23:06 2023 +0800

    Refactoring the variable name `Iface` to `Interface` and `ifaceIndex` to `ifIndex`(keeping it consistent with `ControlMessage`)

    Signed-off-by: ilolicon <[email protected]>

commit 887b4e2
Author: ilolicon <[email protected]>
Date:   Wed Apr 12 10:39:36 2023 +0800

    feat: interface binding

    Signed-off-by: ilolicon <[email protected]>
  • Loading branch information
MatthieuCoder committed Oct 29, 2024
1 parent 9c994ed commit 5284d2f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
7 changes: 6 additions & 1 deletion cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
var usage = `
Usage:
ping [-c count] [-i interval] [-t timeout] [--privileged] host
ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host
Examples:
Expand All @@ -29,6 +29,9 @@ Examples:
# ping google for 10 seconds
ping -t 10s www.google.com
# ping google specified interface
ping -I eth1 www.goole.com
# Send a privileged raw ICMP ping
sudo ping --privileged www.google.com
Expand All @@ -42,6 +45,7 @@ func main() {
count := flag.Int("c", -1, "")
size := flag.Int("s", 24, "")
ttl := flag.Int("l", 64, "TTL")
iface := flag.String("I", "", "interface name")
privileged := flag.Bool("privileged", false, "")
flag.Usage = func() {
fmt.Print(usage)
Expand Down Expand Up @@ -90,6 +94,7 @@ func main() {
pinger.Interval = *interval
pinger.Timeout = *timeout
pinger.TTL = *ttl
pinger.Interface = *iface
pinger.SetPrivileged(*privileged)

fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr())
Expand Down
57 changes: 40 additions & 17 deletions packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ type packetConn interface {
SetMark(m uint) error
SetDoNotFragment() error
SetBroadcastFlag() error
SetIfIndex(ifIndex int)
}

type icmpConn struct {
c *icmp.PacketConn
ttl int
c *icmp.PacketConn
ttl int
ifIndex int
}

func (c *icmpConn) Close() error {
Expand All @@ -36,23 +38,12 @@ func (c *icmpConn) SetTTL(ttl int) {
c.ttl = ttl
}

func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
func (c *icmpConn) SetIfIndex(ifIndex int) {
c.ifIndex = ifIndex
}

func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) {
if c.c.IPv6PacketConn() != nil {
if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil {
return 0, err
}
}
if c.c.IPv4PacketConn() != nil {
if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil {
return 0, err
}
}

return c.c.WriteTo(b, dst)
func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}

type icmpv4Conn struct {
Expand All @@ -76,6 +67,22 @@ func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
return n, ttl, src, err
}

func (c *icmpv4Conn) WriteTo(b []byte, dst net.Addr) (int, error) {
if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil {
return 0, err
}
var cm *ipv4.ControlMessage
if 1 <= c.ifIndex {
// c.ifIndex == 0 if not set interface
if err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagInterface, true); err != nil {
return 0, err
}
cm = &ipv4.ControlMessage{IfIndex: c.ifIndex}
}

return c.c.IPv4PacketConn().WriteTo(b, cm, dst)
}

func (c icmpv4Conn) ICMPRequestType() icmp.Type {
return ipv4.ICMPTypeEcho
}
Expand All @@ -101,6 +108,22 @@ func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
return n, ttl, src, err
}

func (c *icmpV6Conn) WriteTo(b []byte, dst net.Addr) (int, error) {
if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil {
return 0, err
}
var cm *ipv6.ControlMessage
if 1 <= c.ifIndex {
// c.ifIndex == 0 if not set interface
if err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagInterface, true); err != nil {
return 0, err
}
cm = &ipv6.ControlMessage{IfIndex: c.ifIndex}
}

return c.c.IPv6PacketConn().WriteTo(b, cm, dst)
}

func (c icmpV6Conn) ICMPRequestType() icmp.Type {
return ipv6.ICMPTypeEchoRequest
}
10 changes: 10 additions & 0 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ type Pinger struct {
// Source is the source IP address
Source string

// Interface used to send/recv ICMP messages
Interface string

// Channel and mutex used to communicate when the Pinger should stop between goroutines.
done chan interface{}
lock sync.Mutex
Expand Down Expand Up @@ -525,6 +528,13 @@ func (p *Pinger) RunWithContext(ctx context.Context) error {
}

conn.SetTTL(p.TTL)
if p.Interface != "" {
iface, err := net.InterfaceByName(p.Interface)
if err != nil {
return err
}
conn.SetIfIndex(iface.Index)
}
return p.run(ctx, conn)
}

Expand Down
23 changes: 22 additions & 1 deletion ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"net"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -473,6 +474,26 @@ func TestStatisticsZeroDivision(t *testing.T) {
}
}

func TestSetInterfaceName(t *testing.T) {
pinger := New("localhost")
pinger.Count = 1
pinger.Timeout = time.Second

// Set loopback interface
pinger.Interface = "lo"
err := pinger.Run()
if runtime.GOOS == "linux" {
AssertNoError(t, err)
} else {
AssertError(t, err, "other platforms unsupport this feature")
}

// Set fake interface
pinger.Interface = "L()0pB@cK"
err = pinger.Run()
AssertError(t, err, "device not found")
}

// Test helpers
func makeTestPinger() *Pinger {
pinger := New("127.0.0.1")
Expand Down Expand Up @@ -644,7 +665,7 @@ func (c testPacketConn) SetTTL(t int) {}
func (c testPacketConn) SetMark(m uint) error { return nil }
func (c testPacketConn) SetDoNotFragment() error { return nil }
func (c testPacketConn) SetBroadcastFlag() error { return nil }

func (c testPacketConn) SetIfIndex(ifIndex int) {}
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, testAddr, nil
}
Expand Down

0 comments on commit 5284d2f

Please sign in to comment.