Skip to content

Commit

Permalink
feat: interface binding
Browse files Browse the repository at this point in the history
  • Loading branch information
ilolicon committed Apr 12, 2023
1 parent a17ba03 commit c58280f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 18 deletions.
7 changes: 6 additions & 1 deletion cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
var usage = `
Usage:
ping [-c count] [-i interval] [-t timeout] [--privileged] host
ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host
Examples:
Expand All @@ -29,6 +29,9 @@ Examples:
# ping google for 10 seconds
ping -t 10s www.google.com
# ping google specified interface
ping -I eth1 www.goole.com
# Send a privileged raw ICMP ping
sudo ping --privileged www.google.com
Expand All @@ -42,6 +45,7 @@ func main() {
count := flag.Int("c", -1, "")
size := flag.Int("s", 24, "")
ttl := flag.Int("l", 64, "TTL")
iface := flag.String("I", "", "interface name")
privileged := flag.Bool("privileged", false, "")
flag.Usage = func() {
fmt.Print(usage)
Expand Down Expand Up @@ -90,6 +94,7 @@ func main() {
pinger.Interval = *interval
pinger.Timeout = *timeout
pinger.TTL = *ttl
pinger.Iface = *iface
pinger.SetPrivileged(*privileged)

fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr())
Expand Down
57 changes: 40 additions & 17 deletions packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ type packetConn interface {
SetReadDeadline(t time.Time) error
WriteTo(b []byte, dst net.Addr) (int, error)
SetTTL(ttl int)
SetIfaceIndex(ifaceIndex int)
}

type icmpConn struct {
c *icmp.PacketConn
ttl int
c *icmp.PacketConn
ttl int
ifaceIndex int
}

func (c *icmpConn) Close() error {
Expand All @@ -33,23 +35,12 @@ func (c *icmpConn) SetTTL(ttl int) {
c.ttl = ttl
}

func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
func (c *icmpConn) SetIfaceIndex(ifaceIndex int) {
c.ifaceIndex = ifaceIndex
}

func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) {
if c.c.IPv6PacketConn() != nil {
if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil {
return 0, err
}
}
if c.c.IPv4PacketConn() != nil {
if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil {
return 0, err
}
}

return c.c.WriteTo(b, dst)
func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}

type icmpv4Conn struct {
Expand All @@ -73,6 +64,22 @@ func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
return n, ttl, src, err
}

func (c *icmpv4Conn) WriteTo(b []byte, dst net.Addr) (int, error) {
if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil {
return 0, err
}
var cm *ipv4.ControlMessage
if 1 <= c.ifaceIndex {
// c.ifaceIndex == 0 if not set interface
if err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagInterface, true); err != nil {
return 0, err
}
cm = &ipv4.ControlMessage{IfIndex: c.ifaceIndex}
}

return c.c.IPv4PacketConn().WriteTo(b, cm, dst)
}

func (c icmpv4Conn) ICMPRequestType() icmp.Type {
return ipv4.ICMPTypeEcho
}
Expand All @@ -98,6 +105,22 @@ func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
return n, ttl, src, err
}

func (c *icmpV6Conn) WriteTo(b []byte, dst net.Addr) (int, error) {
if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil {
return 0, err
}
var cm *ipv6.ControlMessage
if 1 <= c.ifaceIndex {
// c.ifaceIndex == 0 if not set interface
if err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagInterface, true); err != nil {
return 0, err
}
cm = &ipv6.ControlMessage{IfIndex: c.ifaceIndex}
}

return c.c.IPv6PacketConn().WriteTo(b, cm, dst)
}

func (c icmpV6Conn) ICMPRequestType() icmp.Type {
return ipv6.ICMPTypeEchoRequest
}
10 changes: 10 additions & 0 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ type Pinger struct {
// Source is the source IP address
Source string

// Iface used to send/recv ICMP messages
Iface string

// Channel and mutex used to communicate when the Pinger should stop between goroutines.
done chan interface{}
lock sync.Mutex
Expand Down Expand Up @@ -426,6 +429,13 @@ func (p *Pinger) RunWithContext(ctx context.Context) error {
defer conn.Close()

conn.SetTTL(p.TTL)
if p.Iface != "" {
iface, err := net.InterfaceByName(p.Iface)
if err != nil {
return err
}
conn.SetIfaceIndex(iface.Index)
}
return p.run(ctx, conn)
}

Expand Down
22 changes: 22 additions & 0 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"net"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -477,6 +478,26 @@ func TestStatisticsLossy(t *testing.T) {
}
}

func TestSetIfaceName(t *testing.T) {
pinger := New("localhost")
pinger.Count = 1
pinger.Timeout = time.Second

// Set loopback interface
pinger.Iface = "lo"
err := pinger.Run()
if runtime.GOOS == "linux" {
AssertNoError(t, err)
} else {
AssertError(t, err, "other platforms unsupport this feature")
}

// Set fake interface
pinger.Iface = "L()0pB@cK"
err = pinger.Run()
AssertError(t, err, "device not found")
}

// Test helpers
func makeTestPinger() *Pinger {
pinger := New("127.0.0.1")
Expand Down Expand Up @@ -643,6 +664,7 @@ func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTyp
func (c testPacketConn) SetFlagTTL() error { return nil }
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (c testPacketConn) SetTTL(t int) {}
func (c testPacketConn) SetIfaceIndex(t int) {}

func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, nil
Expand Down

0 comments on commit c58280f

Please sign in to comment.