diff --git a/Makefile b/Makefile index 4a0e5fb..909a418 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,6 @@ build: build-static: GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o $(BUILD_DIR)/$(MAIN_PROGRAM_NAME)-linux-static main.go - + upx $(BUILD_DIR)/$(MAIN_PROGRAM_NAME)-linux-static clean: rm -rf $(BUILD_DIR) \ No newline at end of file diff --git a/cmd/all/all.go b/cmd/all/all.go index e280341..af23219 100644 --- a/cmd/all/all.go +++ b/cmd/all/all.go @@ -1,11 +1,14 @@ package all import ( + "net" "os" command "github.com/esonhugh/k8spider/cmd" "github.com/esonhugh/k8spider/define" "github.com/esonhugh/k8spider/pkg" + "github.com/esonhugh/k8spider/pkg/mutli" + "github.com/esonhugh/k8spider/pkg/scanner" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -24,23 +27,43 @@ var AllCmd = &cobra.Command{ log.Warn("cidr is required") return } + records, err := scanner.DumpAXFR(dns.Fqdn(command.Opts.Zone), "ns.dns."+command.Opts.Zone+":53") + if err == nil { + printResult(records) + } + log.Errorf("Transfer failed: %v", err) ipNets, err := pkg.ParseStringToIPNet(command.Opts.Cidr) if err != nil { log.Warnf("ParseStringToIPNet failed: %v", err) return } - var records define.Records = pkg.ScanSubnet(ipNets) - if records == nil || len(records) == 0 { - log.Warnf("ScanSubnet Found Nothing: %v", err) - return + if command.Opts.BatchMode { + RunBatch(ipNets) + } else { + Run(ipNets) } - records = pkg.ScanSvcForPorts(records) - printResult(records) - records = pkg.DumpAXFR(dns.Fqdn(command.Opts.Zone), "ns.dns."+command.Opts.Zone+":53") - printResult(records) }, } +func Run(net *net.IPNet) { + var records define.Records = scanner.ScanSubnet(net) + if records == nil || len(records) == 0 { + log.Warnf("ScanSubnet Found Nothing") + return + } + records = scanner.ScanSvcForPorts(records) + printResult(records) +} + +func RunBatch(net *net.IPNet) { + scan := mutli.ScanAll(net) + var finalRecord []define.Record + for r := range scan { + finalRecord = append(finalRecord, r...) + } + printResult(finalRecord) +} + func printResult(records define.Records) { if command.Opts.OutputFile != "" { f, err := os.OpenFile(command.Opts.OutputFile, os.O_CREATE|os.O_WRONLY, 0644) diff --git a/cmd/axfr/axfr.go b/cmd/axfr/axfr.go index af58591..7539a36 100644 --- a/cmd/axfr/axfr.go +++ b/cmd/axfr/axfr.go @@ -6,7 +6,7 @@ import ( command "github.com/esonhugh/k8spider/cmd" "github.com/esonhugh/k8spider/define" - "github.com/esonhugh/k8spider/pkg" + "github.com/esonhugh/k8spider/pkg/scanner" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -35,7 +35,12 @@ var AxfrCmd = &cobra.Command{ } log.Debugf("same command: dig axfr %v @%v", zone, dnsServer) - var records define.Records = pkg.DumpAXFR(zone, dnsServer) + var records define.Records + records, err := scanner.DumpAXFR(zone, dnsServer) + if err != nil { + log.Errorf("Transfer failed: %v", err) + return + } if command.Opts.OutputFile != "" { f, err := os.OpenFile(command.Opts.OutputFile, os.O_CREATE|os.O_WRONLY, 0644) if err != nil { diff --git a/cmd/root.go b/cmd/root.go index 13d6a42..3de8515 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -17,7 +17,9 @@ var Opts = struct { SvcDomains []string Zone string OutputFile string - LogLevel string + Verbose string + + BatchMode bool }{} func init() { @@ -26,7 +28,8 @@ func init() { RootCmd.PersistentFlags().StringSliceVarP(&Opts.SvcDomains, "svc-domains", "s", []string{}, "service domains, like: kubernetes.default,etcd.default don't add zone like svc.cluster.local") RootCmd.PersistentFlags().StringVarP(&Opts.Zone, "zone", "z", "cluster.local", "zone") RootCmd.PersistentFlags().StringVarP(&Opts.OutputFile, "output-file", "o", "", "output file") - RootCmd.PersistentFlags().StringVarP(&Opts.LogLevel, "log-level", "l", "info", "log level") + RootCmd.PersistentFlags().StringVarP(&Opts.Verbose, "verbose", "v", "info", "log level (debug,info,trace,warn,error,fatal,panic)") + RootCmd.PersistentFlags().BoolVarP(&Opts.BatchMode, "batch-mode", "b", false, "batch mode") } var RootCmd = &cobra.Command{ @@ -34,7 +37,7 @@ var RootCmd = &cobra.Command{ Short: "k8spider is a tool to discover k8s services", Long: "k8spider is a tool to discover k8s services", PersistentPreRun: func(cmd *cobra.Command, args []string) { - SetLogLevel(Opts.LogLevel) + SetLogLevel(Opts.Verbose) if Opts.DnsServer != "" { pkg.NetResolver = &net.Resolver{ PreferGo: true, diff --git a/cmd/service/service.go b/cmd/service/service.go index e3c561a..b5b6196 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -6,7 +6,7 @@ import ( command "github.com/esonhugh/k8spider/cmd" "github.com/esonhugh/k8spider/define" - "github.com/esonhugh/k8spider/pkg" + "github.com/esonhugh/k8spider/pkg/scanner" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -27,7 +27,7 @@ var ServiceCmd = &cobra.Command{ for _, domain := range command.Opts.SvcDomains { records = append(records, define.Record{SvcDomain: fmt.Sprintf("%s.svc.%s", domain, command.Opts.Zone)}) } - records = pkg.ScanSvcForPorts(records) + records = scanner.ScanSvcForPorts(records) if command.Opts.OutputFile != "" { f, err := os.OpenFile(command.Opts.OutputFile, os.O_CREATE|os.O_WRONLY, 0644) if err != nil { diff --git a/cmd/subnet/subnet.go b/cmd/subnet/subnet.go index 8d7ddd0..4e9064d 100644 --- a/cmd/subnet/subnet.go +++ b/cmd/subnet/subnet.go @@ -1,11 +1,14 @@ package subnet import ( + "net" "os" command "github.com/esonhugh/k8spider/cmd" "github.com/esonhugh/k8spider/define" "github.com/esonhugh/k8spider/pkg" + "github.com/esonhugh/k8spider/pkg/mutli" + "github.com/esonhugh/k8spider/pkg/scanner" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -27,20 +30,45 @@ var SubNetCmd = &cobra.Command{ log.Warnf("ParseStringToIPNet failed: %v", err) return } - var records define.Records = pkg.ScanSubnet(ipNets) - if records == nil || len(records) == 0 { - log.Warnf("ScanSubnet Found Nothing: %v", err) - return - } - if command.Opts.OutputFile != "" { - f, err := os.OpenFile(command.Opts.OutputFile, os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Warnf("OpenFile failed: %v", err) - } - defer f.Close() - records.Print(log.StandardLogger().Writer(), f) + if command.Opts.BatchMode { + BatchRun(ipNets) } else { - records.Print(log.StandardLogger().Writer()) + Run(ipNets) } }, } + +func Run(net *net.IPNet) { + var records define.Records = scanner.ScanSubnet(net) + if records == nil || len(records) == 0 { + log.Warnf("ScanSubnet Found Nothing") + return + } + printResult(records) +} + +func BatchRun(net *net.IPNet) { + scan := mutli.NewSubnetScanner() + var finalRecord []define.Record + for r := range scan.ScanSubnet(net) { + finalRecord = append(finalRecord, r...) + } + if len(finalRecord) == 0 { + log.Warn("ScanSubnet Found Nothing") + return + } + printResult(finalRecord) +} + +func printResult(records define.Records) { + if command.Opts.OutputFile != "" { + f, err := os.OpenFile(command.Opts.OutputFile, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Warnf("OpenFile failed: %v", err) + } + defer f.Close() + records.Print(log.StandardLogger().Writer(), f) + } else { + records.Print(log.StandardLogger().Writer()) + } +} diff --git a/pkg/mutli/executor.go b/pkg/mutli/executor.go new file mode 100644 index 0000000..082faf9 --- /dev/null +++ b/pkg/mutli/executor.go @@ -0,0 +1,13 @@ +package mutli + +import ( + "net" + + "github.com/esonhugh/k8spider/define" +) + +func ScanAll(subnet *net.IPNet) (result <-chan []define.Record) { + subs := NewSubnetScanner() + result = ScanServiceWithChan(subs.ScanSubnet(subnet)) + return result +} diff --git a/pkg/mutli/subnet.go b/pkg/mutli/subnet.go new file mode 100644 index 0000000..3e3cc0a --- /dev/null +++ b/pkg/mutli/subnet.go @@ -0,0 +1,53 @@ +package mutli + +import ( + "net" + "sync" + "time" + + "github.com/esonhugh/k8spider/define" + "github.com/esonhugh/k8spider/pkg" + "github.com/esonhugh/k8spider/pkg/scanner" + log "github.com/sirupsen/logrus" +) + +type SubnetScanner struct { + wg *sync.WaitGroup +} + +func NewSubnetScanner() *SubnetScanner { + return &SubnetScanner{ + wg: new(sync.WaitGroup), + } +} + +func (s *SubnetScanner) ScanSubnet(subnet *net.IPNet) <-chan []define.Record { + if subnet == nil { + log.Debugf("subnet is nil") + return nil + } + out := make(chan []define.Record, 100) + go func() { + log.Debugf("splitting subnet into 16 pices") + if subnets, err := pkg.SubnetShift(subnet, 4); err != nil { + go s.scan(subnet, out) + } else { + for _, sn := range subnets { + go s.scan(sn, out) + } + } + time.Sleep(10 * time.Millisecond) // wait for all goroutines to start + s.wg.Wait() + close(out) + }() + return out +} + +func (s *SubnetScanner) scan(subnet *net.IPNet, to chan []define.Record) { + s.wg.Add(1) + // to <- scanner.ScanSubnet(subnet) + for _, ip := range pkg.ParseIPNetToIPs(subnet) { + to <- scanner.ScanSingleIP(ip) + } + s.wg.Done() +} diff --git a/pkg/mutli/svc.go b/pkg/mutli/svc.go new file mode 100644 index 0000000..e409544 --- /dev/null +++ b/pkg/mutli/svc.go @@ -0,0 +1,17 @@ +package mutli + +import ( + "github.com/esonhugh/k8spider/define" + "github.com/esonhugh/k8spider/pkg/scanner" +) + +func ScanServiceWithChan(rev <-chan []define.Record) <-chan []define.Record { + out := make(chan []define.Record, 100) + go func() { + for records := range rev { + out <- scanner.ScanSvcForPorts(records) + } + close(out) + }() + return out +} diff --git a/pkg/scanner.go b/pkg/scanner/scanner.go similarity index 55% rename from pkg/scanner.go rename to pkg/scanner/scanner.go index b2c1b8a..79ab82b 100644 --- a/pkg/scanner.go +++ b/pkg/scanner/scanner.go @@ -1,17 +1,30 @@ -package pkg +package scanner import ( "net" "strings" "github.com/esonhugh/k8spider/define" + "github.com/esonhugh/k8spider/pkg" "github.com/miekg/dns" log "github.com/sirupsen/logrus" ) +func ScanSingleIP(subnet net.IP) (records []define.Record) { + ptr := pkg.PTRRecord(subnet) + if len(ptr) > 0 { + for _, domain := range ptr { + log.Infof("PTRrecord %v --> %v", subnet, domain) + r := define.Record{Ip: subnet, SvcDomain: domain} + records = append(records, r) + } + } + return +} + func ScanSubnet(subnet *net.IPNet) (records []define.Record) { - for _, ip := range ParseIPNetToIPs(subnet) { - ptr := PTRRecord(ip) + for _, ip := range pkg.ParseIPNetToIPs(subnet) { + ptr := pkg.PTRRecord(ip) if len(ptr) > 0 { for _, domain := range ptr { log.Infof("PTRrecord %v --> %v", ip, domain) @@ -25,9 +38,22 @@ func ScanSubnet(subnet *net.IPNet) (records []define.Record) { return } +func ScanSingleSvcForPorts(records define.Record) define.Record { + cname, srv, err := pkg.SRVRecord(records.SvcDomain) + if err != nil { + log.Debugf("SRVRecord for %v,failed: %v", records.SvcDomain, err) + return records + } + for _, s := range srv { + log.Infof("SRVRecord: %v --> %v:%v", records.SvcDomain, s.Target, s.Port) + } + records.SetSrvRecord(cname, srv) + return records +} + func ScanSvcForPorts(records []define.Record) []define.Record { for i, r := range records { - cname, srv, err := SRVRecord(r.SvcDomain) + cname, srv, err := pkg.SRVRecord(r.SvcDomain) if err != nil { log.Debugf("SRVRecord for %v,failed: %v", r.SvcDomain, err) continue @@ -41,18 +67,18 @@ func ScanSvcForPorts(records []define.Record) []define.Record { } // default target should be zone -func DumpAXFR(target string, dnsServer string) []define.Record { +func DumpAXFR(target string, dnsServer string) ([]define.Record, error) { t := new(dns.Transfer) m := new(dns.Msg) m.SetAxfr(target) ch, err := t.In(m, dnsServer) if err != nil { - log.Fatalf("Transfer failed: %v", err) + return nil, err } var records []define.Record for rr := range ch { if rr.Error != nil { - log.Errorf("Error: %v", rr.Error) + log.Debugf("Error: %v", rr.Error) continue } for _, r := range rr.RR { @@ -63,5 +89,5 @@ func DumpAXFR(target string, dnsServer string) []define.Record { } log.Debugf("Record: %v", rr.RR) } - return records + return records, nil } diff --git a/pkg/subnets.go b/pkg/subnets.go new file mode 100644 index 0000000..26b5a89 --- /dev/null +++ b/pkg/subnets.go @@ -0,0 +1,91 @@ +package pkg + +import ( + "fmt" + "math" + "net" +) + +// SubnetInto wraps SubnetShift and divides a network into at least count-many, +// equal-sized subnets, which are as large as allowed. +func SubnetInto(network *net.IPNet, count int) ([]*net.IPNet, error) { + maskBits, _ := network.Mask.Size() + hostBits := 32 - maskBits + hostCount := 1 << uint(hostBits) + + // divide hosts among subnets + ideal := float64(hostCount) / float64(count) + // largest power of 2, not exceeding the ideal (float64 to int conversion + // truncates toward zero) + newHostBits := int(math.Log2(ideal)) + shift := hostBits - newHostBits + return SubnetShift(network, shift) +} + +// SubnetShift divides a network into subnets by shifting the given number of bits. +func SubnetShift(network *net.IPNet, bits int) ([]*net.IPNet, error) { + if bits < 0 { + return nil, fmt.Errorf("bit shift may not be negative, got %d", bits) + } + if bits > 31 { + return nil, fmt.Errorf("network subnets cannot be divided %d times", bits) + } + // network divides into 2^bits subnets + subnetCount := 1 << uint(bits) + subnets := make([]*net.IPNet, subnetCount) + + // network info + start := network.IP + maskBits, _ := network.Mask.Size() + hostBits := 32 - maskBits + + if maskBits+bits > 32 { + return nil, fmt.Errorf("network subnet mask greater than /32, /%d is invalid", maskBits+bits) + } + + // divide network into subnets + newMaskBits := maskBits + bits + newHostBits := hostBits - bits + // subnet bitmasks are shifted by 'bits' places + newMask := net.CIDRMask(newMaskBits, 32) + + // hosts per subnet + hostCount := 1 << uint(newHostBits) + + for i := 0; i < subnetCount; i++ { + ip := numeric(start) + uint32(i*hostCount) + subnets[i] = &net.IPNet{ + IP: bytewise(ip), + Mask: newMask, + } + } + + return subnets, nil +} + +// IP <-> integer transforms + +// numeric returns a uint32 numeric representation of a net.IP. +func numeric(bytes net.IP) uint32 { + var ip uint32 + // most significant to least significant + for i, b := range []byte(bytes) { + // bitwise or ("append" in this case) + ip |= uint32(b) << (8 * uint32(3-i)) + } + return ip +} + +// bytewise returns a net.IP byte slice alias representation of an uint32. +// Note that not all uint32 values are valid IP addresses. +func bytewise(numeric uint32) net.IP { + ip := make([]byte, 4) + // least significant to most significant + for i := 3; i >= 0; i-- { + // AND away all but least significant + ip[i] = byte(numeric & 0xFF) + // nuke least significant byte + numeric >>= 8 + } + return net.IP(ip) +}