diff --git a/.circleci/config.yml b/.circleci/config.yml index a49fc38..3004057 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,8 +13,8 @@ jobs: use_gomod_cache: type: boolean default: true - docker: - - image: cimg/go:<< parameters.go_version >> + machine: + image: ubuntu-2204:2024.05.1 steps: - checkout - when: diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 645f2e2..acb5b89 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -13,7 +13,7 @@ import ( var usage = ` Usage: - ping [-c count] [-i interval] [-t timeout] [--privileged] host + ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host Examples: @@ -29,6 +29,9 @@ Examples: # ping google for 10 seconds ping -t 10s www.google.com + # ping google specified interface + ping -I eth1 www.goole.com + # Send a privileged raw ICMP ping sudo ping --privileged www.google.com @@ -42,6 +45,7 @@ func main() { count := flag.Int("c", -1, "") size := flag.Int("s", 24, "") ttl := flag.Int("l", 64, "TTL") + iface := flag.String("I", "", "interface name") privileged := flag.Bool("privileged", false, "") flag.Usage = func() { fmt.Print(usage) @@ -90,6 +94,7 @@ func main() { pinger.Interval = *interval pinger.Timeout = *timeout pinger.TTL = *ttl + pinger.InterfaceName = *iface pinger.SetPrivileged(*privileged) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) diff --git a/packetconn.go b/packetconn.go index c4ca820..8528d79 100644 --- a/packetconn.go +++ b/packetconn.go @@ -21,11 +21,13 @@ type packetConn interface { SetMark(m uint) error SetDoNotFragment() error SetBroadcastFlag() error + SetIfIndex(ifIndex int) } type icmpConn struct { - c *icmp.PacketConn - ttl int + c *icmp.PacketConn + ttl int + ifIndex int } func (c *icmpConn) Close() error { @@ -36,23 +38,12 @@ func (c *icmpConn) SetTTL(ttl int) { c.ttl = ttl } -func (c *icmpConn) SetReadDeadline(t time.Time) error { - return c.c.SetReadDeadline(t) +func (c *icmpConn) SetIfIndex(ifIndex int) { + c.ifIndex = ifIndex } -func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) { - if c.c.IPv6PacketConn() != nil { - if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { - return 0, err - } - } - if c.c.IPv4PacketConn() != nil { - if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil { - return 0, err - } - } - - return c.c.WriteTo(b, dst) +func (c *icmpConn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) } type icmpv4Conn struct { @@ -76,6 +67,22 @@ func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { return n, ttl, src, err } +func (c *icmpv4Conn) WriteTo(b []byte, dst net.Addr) (int, error) { + if err := c.c.IPv4PacketConn().SetTTL(c.ttl); err != nil { + return 0, err + } + var cm *ipv4.ControlMessage + if 1 <= c.ifIndex { + // c.ifIndex == 0 if not set interface + if err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagInterface, true); err != nil { + return 0, err + } + cm = &ipv4.ControlMessage{IfIndex: c.ifIndex} + } + + return c.c.IPv4PacketConn().WriteTo(b, cm, dst) +} + func (c icmpv4Conn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } @@ -101,6 +108,22 @@ func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { return n, ttl, src, err } +func (c *icmpV6Conn) WriteTo(b []byte, dst net.Addr) (int, error) { + if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { + return 0, err + } + var cm *ipv6.ControlMessage + if 1 <= c.ifIndex { + // c.ifIndex == 0 if not set interface + if err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagInterface, true); err != nil { + return 0, err + } + cm = &ipv6.ControlMessage{IfIndex: c.ifIndex} + } + + return c.c.IPv6PacketConn().WriteTo(b, cm, dst) +} + func (c icmpV6Conn) ICMPRequestType() icmp.Type { return ipv6.ICMPTypeEchoRequest } diff --git a/ping.go b/ping.go index 9175192..b6de420 100644 --- a/ping.go +++ b/ping.go @@ -207,6 +207,9 @@ type Pinger struct { // Source is the source IP address Source string + // Interface used to send/recv ICMP messages + InterfaceName string + // Channel and mutex used to communicate when the Pinger should stop between goroutines. done chan interface{} lock sync.Mutex @@ -525,6 +528,13 @@ func (p *Pinger) RunWithContext(ctx context.Context) error { } conn.SetTTL(p.TTL) + if p.InterfaceName != "" { + iface, err := net.InterfaceByName(p.InterfaceName) + if err != nil { + return err + } + conn.SetIfIndex(iface.Index) + } return p.run(ctx, conn) } diff --git a/ping_test.go b/ping_test.go index bbbfe9b..d26bd82 100644 --- a/ping_test.go +++ b/ping_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "net" + "runtime" "runtime/debug" "sync" "sync/atomic" @@ -473,6 +474,26 @@ func TestStatisticsZeroDivision(t *testing.T) { } } +func TestSetInterfaceName(t *testing.T) { + pinger := New("localhost") + pinger.Count = 1 + pinger.Timeout = time.Second + + // Set loopback interface + pinger.InterfaceName = "lo" + err := pinger.Run() + if runtime.GOOS == "linux" { + AssertNoError(t, err) + } else { + AssertError(t, err, "other platforms unsupport this feature") + } + + // Set fake interface + pinger.InterfaceName = "L()0pB@cK" + err = pinger.Run() + AssertError(t, err, "device not found") +} + // Test helpers func makeTestPinger() *Pinger { pinger := New("127.0.0.1") @@ -644,7 +665,7 @@ func (c testPacketConn) SetTTL(t int) {} func (c testPacketConn) SetMark(m uint) error { return nil } func (c testPacketConn) SetDoNotFragment() error { return nil } func (c testPacketConn) SetBroadcastFlag() error { return nil } - +func (c testPacketConn) SetIfIndex(ifIndex int) {} func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, testAddr, nil }