Skip to content

Commit 4615f91

Browse files
authored
Model inference (#8)
This patch provides in-process execution of the model using tensorflow.
1 parent a714ea7 commit 4615f91

File tree

10 files changed

+183
-219
lines changed

10 files changed

+183
-219
lines changed

.travis.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
language: go
44

55
env:
6-
- GO111MODULE=on
6+
global:
7+
- GO111MODULE=on
8+
- TENSORFLOW_SRC=https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.15.0.tar.gz
9+
10+
before_install:
11+
- curl -fsSL $TENSORFLOW_SRC | sudo tar -C /usr/ -xzf -
12+
- sudo ldconfig
713

814
go:
915
- "1.12.x"

Corefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
. {
22
dnstun {
3-
runtime 127.0.0.1:5678
4-
detector reverse dns_cnn:latest
3+
graph /var/dnstun/dnscnn.pb
54
}
65
forward . 8.8.8.8
76

README.md

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,22 @@ tunnels.
1313

1414
```txt
1515
dnstun {
16-
runtime HOST:PORT
17-
detector forward|reverse DETECTOR:VERSION
16+
graph PATH
1817
}
1918
```
2019

21-
* `runtime` specifies the endpoint in `HOST:PORT` format to the remote model
22-
runtime. This runtime should comply with e.g. `tensorcraft` HTTP interface.
23-
24-
* `detector` is a directive to configure detector. Option `forward` instructs
25-
the plugin to treat higher probability in the second element of prediction tuple
26-
as DNS tunnel, while `reverse` tells that first element in the prediction tuple
27-
identifies DNS tunnel.
20+
* `graph` is a directive to configure detector. It is a path to the `.pb` file
21+
with constant graph used to classify DNS traffic.
2822

2923
## Examples
3024

3125
Here are the few basic examples of how to enable DNS tunnelling detection.
3226
Usually DNS tunneling detection is turned only for all DNS queries.
3327

34-
Analyze all DNS queries through remote resolver listening on TCP socket.
3528
```txt
3629
. {
3730
dnstun {
38-
# Connect to the runtime that stores model and executes it.
39-
runtime 10.240.0.1:5678
40-
41-
# Choose detector and it's version.
42-
detector reverse dns_cnn:latest
31+
graph /var/dnstun/dnscnn.pb
4332
}
4433
}
4534
```

dnstun.go

Lines changed: 109 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,126 +1,84 @@
11
package dnstun
22

33
import (
4-
"bytes"
54
"context"
6-
"encoding/json"
7-
"net"
8-
"net/http"
9-
"net/url"
10-
"path"
11-
"time"
125

136
"github.com/coredns/coredns/plugin"
147
"github.com/coredns/coredns/request"
158
"github.com/miekg/dns"
169
"github.com/pkg/errors"
17-
)
18-
19-
var (
20-
// DefaultTransport is a default configuration of the Transport.
21-
DefaultTransport http.RoundTripper = &http.Transport{
22-
Proxy: http.ProxyFromEnvironment,
23-
DialContext: (&net.Dialer{
24-
Timeout: 30 * time.Second,
25-
KeepAlive: 30 * time.Second,
26-
DualStack: true,
27-
}).DialContext,
28-
MaxIdleConns: 100,
29-
IdleConnTimeout: 90 * time.Second,
30-
TLSHandshakeTimeout: 10 * time.Second,
31-
ExpectContinueTimeout: 1 * time.Second,
32-
}
33-
34-
// DefaultClient is a default instance of the HTTP client.
35-
DefaultClient = &http.Client{
36-
Transport: DefaultTransport,
37-
}
38-
)
39-
40-
const (
41-
// MappingForward means that first element in the prediction tuple
42-
// is a probability of associating DNS query to the "good" domain
43-
// names. The second element is a probability of "bad" domain.
44-
MappingForward = "forward"
4510

46-
// MappingReverse is reversed representation of probabilities in
47-
// the prediction tuple returned by the model.
48-
MappingReverse = "reverse"
11+
tf "github.com/tensorflow/tensorflow/tensorflow/go"
12+
tfop "github.com/tensorflow/tensorflow/tensorflow/go/op"
4913
)
5014

51-
// mappings lists all available mapping types.
52-
var mappings = map[string]struct{}{
53-
MappingForward: struct{}{},
54-
MappingReverse: struct{}{},
55-
}
56-
5715
type Options struct {
58-
Mapping string
59-
Model string
60-
Version string
61-
Runtime string
16+
Graph string
6217
}
6318

6419
// Dnstun is a plugin to block DNS tunneling queries.
6520
type Dnstun struct {
66-
opts Options
67-
client *http.Client
68-
tokenizer Tokenizer
21+
predictGraph execGraph
22+
argmaxGraph execGraph
23+
tokenizer Tokenizer
6924
}
7025

7126
// NewDnstun creates a new instance of the DNS tunneling detector plugin.
72-
func NewDnstun(opts Options) *Dnstun {
27+
func NewDnstun(predictGraph *tf.Graph) *Dnstun {
7328
return &Dnstun{
74-
opts: opts,
75-
client: DefaultClient,
76-
tokenizer: NewTokenizer(enUS, 256),
29+
predictGraph: newExecGraph(predictGraph),
30+
argmaxGraph: newExecGraph(newArgmax([]int64{1, 2}, tf.Float, 1)),
31+
tokenizer: NewTokenizer(enUS, 256),
7732
}
7833
}
7934

8035
func (d *Dnstun) Name() string {
8136
return "dnstun"
8237
}
8338

84-
func (d *Dnstun) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
85-
var (
86-
state = request.Request{W: w, Req: r}
87-
resp PredictResponse
88-
)
89-
90-
req := PredictRequest{
91-
X: [][]int{d.tokenizer.TextToSeq(state.QName())},
39+
func (d *Dnstun) predict(name string) (int64, error) {
40+
input, err := tf.NewTensor([][]int64{d.tokenizer.TextToSeq(name)})
41+
if err != nil {
42+
return -1, err
9243
}
9344

94-
p := path.Join("/models", d.opts.Model, d.opts.Version, "predict")
95-
96-
u := url.URL{Scheme: "http", Host: d.opts.Runtime, Path: p}
97-
err := d.do(ctx, "POST", &u, req, &resp)
45+
output, err := d.predictGraph.Exec(input)
9846
if err != nil {
99-
return dns.RcodeServerFailure, plugin.Error(d.Name(), err)
47+
return -1, err
10048
}
101-
102-
if len(resp.Y) != 1 || len(resp.Y[0]) == 0 {
103-
err = errors.Errorf("invalid predict response: %#v", resp)
104-
return dns.RcodeServerFailure, plugin.Error(d.Name(), err)
49+
if len(output) == 0 {
50+
return -1, errors.New("prediction graph returned empty tensor")
10551
}
10652

10753
// Select max argument position from the response vector.
108-
var (
109-
yPos int = 0
110-
yMax float64 = resp.Y[0][yPos]
111-
)
112-
for i := yPos + 1; i < len(resp.Y[0]); i++ {
113-
if resp.Y[0][i] > yMax {
114-
yPos = i
115-
yMax = resp.Y[0][i]
116-
}
54+
output, err = d.argmaxGraph.Exec(output[0])
55+
if err != nil {
56+
return -1, err
57+
}
58+
if len(output) == 0 {
59+
return -1, errors.New("argmax returned empty tensor")
60+
}
61+
index, ok := output[0].Value().([]int64)
62+
if !ok {
63+
return -1, errors.Errorf("unexpected output type %T", output[0].Value())
64+
}
65+
if len(index) == 0 {
66+
return -1, errors.New("argmax return empty result")
67+
}
68+
return index[0], nil
69+
}
70+
71+
func (d *Dnstun) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
72+
state := request.Request{W: w, Req: r}
73+
74+
category, err := d.predict(state.QName())
75+
if err != nil {
76+
return dns.RcodeServerFailure, plugin.Error(d.Name(), err)
11777
}
11878

11979
// The first position of the prediction vector corresponds to the DNS
12080
// tunneling class, therefore such requests should be rejected.
121-
if (d.opts.Mapping == MappingForward && yPos == 1) ||
122-
(d.opts.Mapping == MappingReverse && yPos == 0) {
123-
81+
if category == 0 {
12482
m := new(dns.Msg)
12583
m.SetRcode(r, dns.RcodeRefused)
12684
w.WriteMsg(m)
@@ -131,59 +89,6 @@ func (d *Dnstun) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
13189
return dns.RcodeSuccess, nil
13290
}
13391

134-
// PredictRequest is a request to get predictions for the given attribute vectors.
135-
type PredictRequest struct {
136-
X [][]int `json:"x"`
137-
}
138-
139-
// PredictResponse lists probabilities for each attribute vector.
140-
type PredictResponse struct {
141-
Y [][]float64 `json:"y"`
142-
}
143-
144-
func (d *Dnstun) do(ctx context.Context, method string, u *url.URL, in, out interface{}) error {
145-
var (
146-
b []byte
147-
err error
148-
)
149-
150-
if in != nil {
151-
b, err = json.Marshal(in)
152-
if err != nil {
153-
return errors.Wrapf(err, "failed to encode request")
154-
}
155-
}
156-
req, err := http.NewRequest(method, u.String(), bytes.NewReader(b))
157-
if err != nil {
158-
return err
159-
}
160-
resp, err := d.client.Do(req.WithContext(ctx))
161-
if err != nil {
162-
return err
163-
}
164-
165-
// Decode the list of nodes from the body of the response.
166-
defer resp.Body.Close()
167-
168-
// If server returned non-zero status, the response body is treated
169-
// as a error message, which will be returned to the user.
170-
if resp.StatusCode != http.StatusOK {
171-
// Server could return a response error within a header.
172-
errorCode := resp.Header.Get(http.CanonicalHeaderKey("Error-Code"))
173-
if errorCode != "" {
174-
return errors.New(errorCode)
175-
}
176-
return errors.Errorf("unexpected response from server: %d", resp.StatusCode)
177-
}
178-
179-
if out == nil {
180-
return nil
181-
}
182-
183-
decoder := json.NewDecoder(resp.Body)
184-
return errors.Wrapf(decoder.Decode(out), "failed to decode response")
185-
}
186-
18792
type chainHandler struct {
18893
plugin.Handler
18994
next plugin.Handler
@@ -204,3 +109,70 @@ func (p chainHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
204109
state := request.Request{W: w, Req: r}
205110
return plugin.NextOrFailure(state.Name(), p.next, ctx, w, r)
206111
}
112+
113+
func newArgmax(shape []int64, dtype tf.DataType, dim int64) (graph *tf.Graph) {
114+
inShape := tf.MakeShape(shape...)
115+
root := tfop.NewScope()
116+
117+
input := tfop.Placeholder(root, dtype, tfop.PlaceholderShape(inShape))
118+
tfop.ArgMax(root, input, tfop.Const(root, dim))
119+
120+
graph, err := root.Finalize()
121+
if err != nil {
122+
panic(err)
123+
}
124+
return graph
125+
}
126+
127+
type execGraph struct {
128+
graphInput tf.Output
129+
graphOutput tf.Output
130+
graph *tf.Graph
131+
}
132+
133+
func newExecGraph(graph *tf.Graph) execGraph {
134+
var (
135+
ops = graph.Operations()
136+
input tf.Output
137+
)
138+
139+
for _, o := range ops {
140+
if o.Type() == "Placeholder" {
141+
input = o.Output(0)
142+
break
143+
}
144+
}
145+
146+
if input == (tf.Output{}) {
147+
panic("graph without input")
148+
}
149+
return execGraph{
150+
graphInput: input,
151+
graphOutput: ops[len(ops)-1].Output(0),
152+
graph: graph,
153+
}
154+
}
155+
156+
func (e execGraph) Exec(in *tf.Tensor) (output []*tf.Tensor, err error) {
157+
sess, err := tf.NewSession(e.graph, nil)
158+
if err != nil {
159+
return nil, err
160+
}
161+
162+
defer func() {
163+
e := sess.Close()
164+
if e != nil {
165+
if err != nil {
166+
err = errors.WithMessage(err, e.Error())
167+
} else {
168+
err = e
169+
}
170+
}
171+
}()
172+
173+
return sess.Run(
174+
map[tf.Output]*tf.Tensor{e.graphInput: in},
175+
[]tf.Output{e.graphOutput},
176+
nil,
177+
)
178+
}

0 commit comments

Comments
 (0)