From c58280fb3c65869d9a4c7fb2d0d20bbc476d1270 Mon Sep 17 00:00:00 2001 From: ilolicon <97431110@qq.com> Date: Wed, 12 Apr 2023 10:39:36 +0800 Subject: [PATCH] feat: interface binding --- cmd/ping/ping.go | 7 +++++- packetconn.go | 57 +++++++++++++++++++++++++++++++++--------------- ping.go | 10 +++++++++ ping_test.go | 22 +++++++++++++++++++ 4 files changed, 78 insertions(+), 18 deletions(-) diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 645f2e2..aae6b98 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -13,7 +13,7 @@ import ( var usage = ` Usage: - ping [-c count] [-i interval] [-t timeout] [--privileged] host + ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host Examples: @@ -29,6 +29,9 @@ Examples: # ping google for 10 seconds ping -t 10s www.google.com + # ping google specified interface + ping -I eth1 www.goole.com + # Send a privileged raw ICMP ping sudo ping --privileged www.google.com @@ -42,6 +45,7 @@ func main() { count := flag.Int("c", -1, "") size := flag.Int("s", 24, "") ttl := flag.Int("l", 64, "TTL") + iface := flag.String("I", "", "interface name") privileged := flag.Bool("privileged", false, "") flag.Usage = func() { fmt.Print(usage) @@ -90,6 +94,7 @@ func main() { pinger.Interval = *interval pinger.Timeout = *timeout pinger.TTL = *ttl + pinger.Iface = *iface pinger.SetPrivileged(*privileged) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) diff --git a/packetconn.go b/packetconn.go index 4e469ba..80336a8 100644 --- a/packetconn.go +++ b/packetconn.go @@ -18,11 +18,13 @@ type packetConn interface { SetReadDeadline(t time.Time) error WriteTo(b []byte, dst net.Addr) (int, error) SetTTL(ttl int) + SetIfaceIndex(ifaceIndex int) } type icmpConn struct { - c *icmp.PacketConn - ttl int + c *icmp.PacketConn + ttl int + ifaceIndex int } func (c *icmpConn) Close() error { @@ -33,23 +35,12 @@ func (c *icmpConn) SetTTL(ttl int) { c.ttl = ttl } -func (c *icmpConn) SetReadDeadline(t time.Time) error { - return c.c.SetReadDeadline(t) +func (c *icmpConn) SetIfaceIndex(ifaceIndex int) { + c.ifaceIndex = ifaceIndex } -func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) { - if c.c.IPv6PacketConn() != nil { - if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { - return 0, err - } - } - if c.c.IPv4PacketConn() != nil { - if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil { - return 0, err - } - } - - return c.c.WriteTo(b, dst) +func (c *icmpConn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) } type icmpv4Conn struct { @@ -73,6 +64,22 @@ func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { return n, ttl, src, err } +func (c *icmpv4Conn) WriteTo(b []byte, dst net.Addr) (int, error) { + if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil { + return 0, err + } + var cm *ipv4.ControlMessage + if 1 <= c.ifaceIndex { + // c.ifaceIndex == 0 if not set interface + if err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagInterface, true); err != nil { + return 0, err + } + cm = &ipv4.ControlMessage{IfIndex: c.ifaceIndex} + } + + return c.c.IPv4PacketConn().WriteTo(b, cm, dst) +} + func (c icmpv4Conn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } @@ -98,6 +105,22 @@ func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { return n, ttl, src, err } +func (c *icmpV6Conn) WriteTo(b []byte, dst net.Addr) (int, error) { + if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { + return 0, err + } + var cm *ipv6.ControlMessage + if 1 <= c.ifaceIndex { + // c.ifaceIndex == 0 if not set interface + if err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagInterface, true); err != nil { + return 0, err + } + cm = &ipv6.ControlMessage{IfIndex: c.ifaceIndex} + } + + return c.c.IPv6PacketConn().WriteTo(b, cm, dst) +} + func (c icmpV6Conn) ICMPRequestType() icmp.Type { return ipv6.ICMPTypeEchoRequest } diff --git a/ping.go b/ping.go index caa4bd9..b2013ca 100644 --- a/ping.go +++ b/ping.go @@ -182,6 +182,9 @@ type Pinger struct { // Source is the source IP address Source string + // Iface used to send/recv ICMP messages + Iface string + // Channel and mutex used to communicate when the Pinger should stop between goroutines. done chan interface{} lock sync.Mutex @@ -426,6 +429,13 @@ func (p *Pinger) RunWithContext(ctx context.Context) error { defer conn.Close() conn.SetTTL(p.TTL) + if p.Iface != "" { + iface, err := net.InterfaceByName(p.Iface) + if err != nil { + return err + } + conn.SetIfaceIndex(iface.Index) + } return p.run(ctx, conn) } diff --git a/ping_test.go b/ping_test.go index c7baf7f..d0396e3 100644 --- a/ping_test.go +++ b/ping_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "net" + "runtime" "runtime/debug" "sync" "sync/atomic" @@ -477,6 +478,26 @@ func TestStatisticsLossy(t *testing.T) { } } +func TestSetIfaceName(t *testing.T) { + pinger := New("localhost") + pinger.Count = 1 + pinger.Timeout = time.Second + + // Set loopback interface + pinger.Iface = "lo" + err := pinger.Run() + if runtime.GOOS == "linux" { + AssertNoError(t, err) + } else { + AssertError(t, err, "other platforms unsupport this feature") + } + + // Set fake interface + pinger.Iface = "L()0pB@cK" + err = pinger.Run() + AssertError(t, err, "device not found") +} + // Test helpers func makeTestPinger() *Pinger { pinger := New("127.0.0.1") @@ -643,6 +664,7 @@ func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTyp func (c testPacketConn) SetFlagTTL() error { return nil } func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c testPacketConn) SetTTL(t int) {} +func (c testPacketConn) SetIfaceIndex(t int) {} func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, nil