1
1
package dnstun
2
2
3
3
import (
4
- "bytes"
5
4
"context"
6
- "encoding/json"
7
- "net"
8
- "net/http"
9
- "net/url"
10
- "path"
11
- "time"
12
5
13
6
"github.com/coredns/coredns/plugin"
14
7
"github.com/coredns/coredns/request"
15
8
"github.com/miekg/dns"
16
9
"github.com/pkg/errors"
17
- )
18
-
19
- var (
20
- // DefaultTransport is a default configuration of the Transport.
21
- DefaultTransport http.RoundTripper = & http.Transport {
22
- Proxy : http .ProxyFromEnvironment ,
23
- DialContext : (& net.Dialer {
24
- Timeout : 30 * time .Second ,
25
- KeepAlive : 30 * time .Second ,
26
- DualStack : true ,
27
- }).DialContext ,
28
- MaxIdleConns : 100 ,
29
- IdleConnTimeout : 90 * time .Second ,
30
- TLSHandshakeTimeout : 10 * time .Second ,
31
- ExpectContinueTimeout : 1 * time .Second ,
32
- }
33
-
34
- // DefaultClient is a default instance of the HTTP client.
35
- DefaultClient = & http.Client {
36
- Transport : DefaultTransport ,
37
- }
38
- )
39
-
40
- const (
41
- // MappingForward means that first element in the prediction tuple
42
- // is a probability of associating DNS query to the "good" domain
43
- // names. The second element is a probability of "bad" domain.
44
- MappingForward = "forward"
45
10
46
- // MappingReverse is reversed representation of probabilities in
47
- // the prediction tuple returned by the model.
48
- MappingReverse = "reverse"
11
+ tf "github.com/tensorflow/tensorflow/tensorflow/go"
12
+ tfop "github.com/tensorflow/tensorflow/tensorflow/go/op"
49
13
)
50
14
51
- // mappings lists all available mapping types.
52
- var mappings = map [string ]struct {}{
53
- MappingForward : struct {}{},
54
- MappingReverse : struct {}{},
55
- }
56
-
57
15
type Options struct {
58
- Mapping string
59
- Model string
60
- Version string
61
- Runtime string
16
+ Graph string
62
17
}
63
18
64
19
// Dnstun is a plugin to block DNS tunneling queries.
65
20
type Dnstun struct {
66
- opts Options
67
- client * http. Client
68
- tokenizer Tokenizer
21
+ predictGraph execGraph
22
+ argmaxGraph execGraph
23
+ tokenizer Tokenizer
69
24
}
70
25
71
26
// NewDnstun creates a new instance of the DNS tunneling detector plugin.
72
- func NewDnstun (opts Options ) * Dnstun {
27
+ func NewDnstun (predictGraph * tf. Graph ) * Dnstun {
73
28
return & Dnstun {
74
- opts : opts ,
75
- client : DefaultClient ,
76
- tokenizer : NewTokenizer (enUS , 256 ),
29
+ predictGraph : newExecGraph ( predictGraph ) ,
30
+ argmaxGraph : newExecGraph ( newArgmax ([] int64 { 1 , 2 }, tf . Float , 1 )) ,
31
+ tokenizer : NewTokenizer (enUS , 256 ),
77
32
}
78
33
}
79
34
80
35
func (d * Dnstun ) Name () string {
81
36
return "dnstun"
82
37
}
83
38
84
- func (d * Dnstun ) ServeDNS (ctx context.Context , w dns.ResponseWriter , r * dns.Msg ) (int , error ) {
85
- var (
86
- state = request.Request {W : w , Req : r }
87
- resp PredictResponse
88
- )
89
-
90
- req := PredictRequest {
91
- X : [][]int {d .tokenizer .TextToSeq (state .QName ())},
39
+ func (d * Dnstun ) predict (name string ) (int64 , error ) {
40
+ input , err := tf .NewTensor ([][]int64 {d .tokenizer .TextToSeq (name )})
41
+ if err != nil {
42
+ return - 1 , err
92
43
}
93
44
94
- p := path .Join ("/models" , d .opts .Model , d .opts .Version , "predict" )
95
-
96
- u := url.URL {Scheme : "http" , Host : d .opts .Runtime , Path : p }
97
- err := d .do (ctx , "POST" , & u , req , & resp )
45
+ output , err := d .predictGraph .Exec (input )
98
46
if err != nil {
99
- return dns . RcodeServerFailure , plugin . Error ( d . Name (), err )
47
+ return - 1 , err
100
48
}
101
-
102
- if len (resp .Y ) != 1 || len (resp .Y [0 ]) == 0 {
103
- err = errors .Errorf ("invalid predict response: %#v" , resp )
104
- return dns .RcodeServerFailure , plugin .Error (d .Name (), err )
49
+ if len (output ) == 0 {
50
+ return - 1 , errors .New ("prediction graph returned empty tensor" )
105
51
}
106
52
107
53
// Select max argument position from the response vector.
108
- var (
109
- yPos int = 0
110
- yMax float64 = resp .Y [0 ][yPos ]
111
- )
112
- for i := yPos + 1 ; i < len (resp .Y [0 ]); i ++ {
113
- if resp .Y [0 ][i ] > yMax {
114
- yPos = i
115
- yMax = resp .Y [0 ][i ]
116
- }
54
+ output , err = d .argmaxGraph .Exec (output [0 ])
55
+ if err != nil {
56
+ return - 1 , err
57
+ }
58
+ if len (output ) == 0 {
59
+ return - 1 , errors .New ("argmax returned empty tensor" )
60
+ }
61
+ index , ok := output [0 ].Value ().([]int64 )
62
+ if ! ok {
63
+ return - 1 , errors .Errorf ("unexpected output type %T" , output [0 ].Value ())
64
+ }
65
+ if len (index ) == 0 {
66
+ return - 1 , errors .New ("argmax return empty result" )
67
+ }
68
+ return index [0 ], nil
69
+ }
70
+
71
+ func (d * Dnstun ) ServeDNS (ctx context.Context , w dns.ResponseWriter , r * dns.Msg ) (int , error ) {
72
+ state := request.Request {W : w , Req : r }
73
+
74
+ category , err := d .predict (state .QName ())
75
+ if err != nil {
76
+ return dns .RcodeServerFailure , plugin .Error (d .Name (), err )
117
77
}
118
78
119
79
// The first position of the prediction vector corresponds to the DNS
120
80
// tunneling class, therefore such requests should be rejected.
121
- if (d .opts .Mapping == MappingForward && yPos == 1 ) ||
122
- (d .opts .Mapping == MappingReverse && yPos == 0 ) {
123
-
81
+ if category == 0 {
124
82
m := new (dns.Msg )
125
83
m .SetRcode (r , dns .RcodeRefused )
126
84
w .WriteMsg (m )
@@ -131,59 +89,6 @@ func (d *Dnstun) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
131
89
return dns .RcodeSuccess , nil
132
90
}
133
91
134
- // PredictRequest is a request to get predictions for the given attribute vectors.
135
- type PredictRequest struct {
136
- X [][]int `json:"x"`
137
- }
138
-
139
- // PredictResponse lists probabilities for each attribute vector.
140
- type PredictResponse struct {
141
- Y [][]float64 `json:"y"`
142
- }
143
-
144
- func (d * Dnstun ) do (ctx context.Context , method string , u * url.URL , in , out interface {}) error {
145
- var (
146
- b []byte
147
- err error
148
- )
149
-
150
- if in != nil {
151
- b , err = json .Marshal (in )
152
- if err != nil {
153
- return errors .Wrapf (err , "failed to encode request" )
154
- }
155
- }
156
- req , err := http .NewRequest (method , u .String (), bytes .NewReader (b ))
157
- if err != nil {
158
- return err
159
- }
160
- resp , err := d .client .Do (req .WithContext (ctx ))
161
- if err != nil {
162
- return err
163
- }
164
-
165
- // Decode the list of nodes from the body of the response.
166
- defer resp .Body .Close ()
167
-
168
- // If server returned non-zero status, the response body is treated
169
- // as a error message, which will be returned to the user.
170
- if resp .StatusCode != http .StatusOK {
171
- // Server could return a response error within a header.
172
- errorCode := resp .Header .Get (http .CanonicalHeaderKey ("Error-Code" ))
173
- if errorCode != "" {
174
- return errors .New (errorCode )
175
- }
176
- return errors .Errorf ("unexpected response from server: %d" , resp .StatusCode )
177
- }
178
-
179
- if out == nil {
180
- return nil
181
- }
182
-
183
- decoder := json .NewDecoder (resp .Body )
184
- return errors .Wrapf (decoder .Decode (out ), "failed to decode response" )
185
- }
186
-
187
92
type chainHandler struct {
188
93
plugin.Handler
189
94
next plugin.Handler
@@ -204,3 +109,70 @@ func (p chainHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
204
109
state := request.Request {W : w , Req : r }
205
110
return plugin .NextOrFailure (state .Name (), p .next , ctx , w , r )
206
111
}
112
+
113
+ func newArgmax (shape []int64 , dtype tf.DataType , dim int64 ) (graph * tf.Graph ) {
114
+ inShape := tf .MakeShape (shape ... )
115
+ root := tfop .NewScope ()
116
+
117
+ input := tfop .Placeholder (root , dtype , tfop .PlaceholderShape (inShape ))
118
+ tfop .ArgMax (root , input , tfop .Const (root , dim ))
119
+
120
+ graph , err := root .Finalize ()
121
+ if err != nil {
122
+ panic (err )
123
+ }
124
+ return graph
125
+ }
126
+
127
+ type execGraph struct {
128
+ graphInput tf.Output
129
+ graphOutput tf.Output
130
+ graph * tf.Graph
131
+ }
132
+
133
+ func newExecGraph (graph * tf.Graph ) execGraph {
134
+ var (
135
+ ops = graph .Operations ()
136
+ input tf.Output
137
+ )
138
+
139
+ for _ , o := range ops {
140
+ if o .Type () == "Placeholder" {
141
+ input = o .Output (0 )
142
+ break
143
+ }
144
+ }
145
+
146
+ if input == (tf.Output {}) {
147
+ panic ("graph without input" )
148
+ }
149
+ return execGraph {
150
+ graphInput : input ,
151
+ graphOutput : ops [len (ops )- 1 ].Output (0 ),
152
+ graph : graph ,
153
+ }
154
+ }
155
+
156
+ func (e execGraph ) Exec (in * tf.Tensor ) (output []* tf.Tensor , err error ) {
157
+ sess , err := tf .NewSession (e .graph , nil )
158
+ if err != nil {
159
+ return nil , err
160
+ }
161
+
162
+ defer func () {
163
+ e := sess .Close ()
164
+ if e != nil {
165
+ if err != nil {
166
+ err = errors .WithMessage (err , e .Error ())
167
+ } else {
168
+ err = e
169
+ }
170
+ }
171
+ }()
172
+
173
+ return sess .Run (
174
+ map [tf.Output ]* tf.Tensor {e .graphInput : in },
175
+ []tf.Output {e .graphOutput },
176
+ nil ,
177
+ )
178
+ }
0 commit comments