-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
148 lines (124 loc) · 3.11 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package main
import (
"bufio"
"fmt"
"log"
"net"
"net/http"
_ "net/http/pprof"
"sync"
"time"
)
// main starts the HTTP endpoint on localhost:6060, binds a UDP listener on
// 127.0.0.1:1053, opens a UDP socket to the upstream resolver at 8.8.8.8:53,
// and then hands every received datagram to its own worker goroutine.
//
// Any failure to set up the listener or the upstream socket, and any read
// error on the listener, terminates the process through the log package so
// that the exit status is non-zero.
func main() {
	// Serve the default mux in the background; the blank net/http/pprof
	// import in this file registers its handlers there.
	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()

	// NOTE(review): the meaning of 60 is defined by NewDnsCache, which is not
	// visible here — presumably a TTL or capacity; confirm.
	dnsCache := NewDnsCache(60)

	listenAddr := net.UDPAddr{
		Port: 1053,
		IP:   net.ParseIP("127.0.0.1"),
	}
	fmt.Println("listening on udp")
	ln, err := net.ListenUDP("udp", &listenAddr)
	if err != nil {
		log.Fatal(err)
	}

	upstream, err := net.Dial("udp", "8.8.8.8:53")
	if err != nil {
		// Fail the same way as the listener setup above instead of printing
		// to stdout and returning with a zero exit status.
		log.Fatalf("Some error %v", err)
	}

	// The loop below only ends through log.Fatal, so wg is never waited on;
	// it exists because worker reports completion through it.
	var wg sync.WaitGroup
	for {
		// A fresh array per iteration: it is passed to the worker by value,
		// so each goroutine owns its own copy of the datagram.
		var buf [512]uint8
		fmt.Println("waiting to read from udp")
		_, clientAddr, err := ln.ReadFromUDP(buf[:])
		if err != nil {
			fmt.Println("cant read from udp??")
			log.Fatal(err)
		}
		fmt.Println("read something.")
		fmt.Println("Starting worker")
		wg.Add(1)
		go worker(&wg, true, ln, clientAddr, buf, &upstream, dnsCache)
	}
}
// worker handles a single client datagram.
//
// It decodes buf, checks the cache for each question name, and either
//   - replies to addr with a locally built packet when at least one question
//     name has a cached value, or
//   - forwards the query to upstream, stores the A-record answers from the
//     reply in the cache, and relays the reply bytes back to addr over conn.
//
// wg.Done is called when the function returns. When debug is true the decoded
// upstream reply is dumped to stdout after it has been relayed.
//
// NOTE(review): upstream is shared by every worker goroutine started by the
// caller, and no read deadline is set here, so a reply may be picked up by a
// different worker than the one that sent the query and a lost reply blocks
// this goroutine — confirm whether a per-query socket is wanted.
func worker(wg *sync.WaitGroup, debug bool, conn *net.UDPConn, addr *net.UDPAddr,
	buf [512]uint8, upstream *net.Conn, cache *DnsCache) {
	defer wg.Done()

	recBuf := NewPacket()
	recBuf.buf = buf
	recPacket := FromBuffer(&recBuf)
	fmt.Printf("Decoded the packet. making another.\n")

	packet := NewDnsPacket()
	packet.header.id = recPacket.header.id
	packet.header.questions = recPacket.header.questions
	packet.header.recursionDesired = true
	packet.header.z = false

	fmt.Printf("checking cache\n")
	cached := false
	for _, q := range recPacket.questions {
		cache.mutex.Lock()
		// Single map lookup; a missing key yields the zero value (nil entry).
		if entry := cache.cache[q.name]; entry != nil && entry.value != nil {
			packet.questions = append(packet.questions, NewDnsQuestion(q.name, A))
			cached = true
		}
		cache.mutex.Unlock()
	}

	if cached {
		// NOTE(review): only questions are added to the reply here; the
		// cached value itself is not attached as an answer in this function.
		// Confirm whether toBuffer or the client is expected to cope.
		resBuf := NewPacket()
		packet.toBuffer(&resBuf)
		if _, err := conn.WriteToUDP(resBuf.buf[0:resBuf.pos], addr); err != nil {
			fmt.Printf("Some error while writing back to client: %v\n", err)
		}
		return
	}
	// Only reported on an actual miss; previously this was printed even when
	// the cache-hit branch above was about to answer.
	fmt.Printf("not present in the cache.\n")

	// Forward the query. The length comes from recBuf.pos after decoding,
	// because the received byte count is not passed to this function.
	num, err := (*upstream).Write(recBuf.buf[0:recBuf.pos])
	if err != nil {
		fmt.Printf("Some error while writing: %v\n", err)
		return
	}
	fmt.Printf("Wrote %#v bytes\n", num)

	resBuf := NewPacket()
	num, err = bufio.NewReader(*upstream).Read(resBuf.buf[:])
	if err != nil {
		fmt.Printf("Some error while reading %v\n", err)
		return
	}
	fmt.Printf("Read %#v bytes\n", num)

	resPacket := FromBuffer(&resBuf)
	for _, a := range resPacket.answers {
		domain := a.ARecord.domain
		if domain == "" {
			continue
		}
		// Build the entry first so the map is written once under the lock.
		entry := &item{}
		entry.value = a.ARecord.addr
		entry.lastAccess = time.Now().Unix()
		cache.mutex.Lock()
		cache.cache[domain] = entry
		cache.mutex.Unlock()
	}

	// Relay exactly the bytes received from upstream rather than slicing by
	// the decoder's cursor, so the reply is complete regardless of how far
	// FromBuffer advanced.
	if _, err = conn.WriteToUDP(resBuf.buf[:num], addr); err != nil {
		fmt.Printf("Some error while writing back to client 2 %v\n", err)
		return
	}

	if !debug {
		return
	}
	for _, q := range resPacket.questions {
		fmt.Printf("\nquestions: %#v\n", q)
	}
	fmt.Println("num answers", len(resPacket.answers))
	for _, a := range resPacket.answers {
		fmt.Printf("\nanswers: %#v\n", a)
	}
	for _, au := range resPacket.authorities {
		fmt.Println("authorities : ", au)
	}
	for _, r := range resPacket.resources {
		fmt.Println("resources : ", r)
	}
}