-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathhttps_test.go
123 lines (106 loc) · 3.35 KB
/
https_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package https
import (
"context"
"errors"
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// newRequestDNSMsg builds the DNS query used as the request in these tests:
// a single question for the A record of "example.com." in the IN class.
// A fresh message is returned on every call, so callers may mutate it freely.
func newRequestDNSMsg() *dns.Msg {
	question := dns.Question{
		Name:   "example.com.",
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}

	msg := &dns.Msg{}
	msg.Question = []dns.Question{question}
	return msg
}
// TestHTTPS covers the happy path. ServeDNS must succeed with NOERROR and
// invoke the DNS client exactly once. The message written to the client must
// equal newExpectedDNSMsg().
func TestHTTPS(t *testing.T) {
	req := newRequestDNSMsg()
	packed, err := req.Pack()
	require.NoError(t, err)

	// mockDNSClient is declared elsewhere in the package. It is given the
	// packed request bytes; presumably it compares them with what it
	// receives (TODO confirm against its definition).
	client := &mockDNSClient{reqBody: packed, t: t}
	handler := newHTTPS(".", client)
	recorder := dnstest.NewRecorder(&test.ResponseWriter{})

	rcode, err := handler.ServeDNS(context.Background(), recorder, req)
	require.NoError(t, err)
	require.Equal(t, dns.RcodeSuccess, rcode)
	require.Equal(t, 1, client.callCount, "dnsClient call count is wrong")
	require.Equal(t, newExpectedDNSMsg(), recorder.Msg)
}
// mockDNSClientFunc lets a plain function stand in for the DNS client that
// newHTTPS is constructed with, much like http.HandlerFunc. Query forwards
// straight to the function.
type mockDNSClientFunc func(ctx context.Context, dnsreq []byte) (*dns.Msg, error)

// Query calls the underlying function with ctx and the packed request and
// returns its results unchanged.
func (fn mockDNSClientFunc) Query(ctx context.Context, dnsreq []byte) (*dns.Msg, error) {
	msg, err := fn(ctx, dnsreq)
	return msg, err
}
// mockDNSResponseWriter wraps a dns.ResponseWriter and overrides only
// WriteMsg, which is routed to a test-supplied callback. All other methods
// are promoted from the embedded writer.
type mockDNSResponseWriter struct {
	dns.ResponseWriter

	// writeFunc is called by WriteMsg, and its error is returned as-is.
	writeFunc func(*dns.Msg) error
}

// WriteMsg delegates to w.writeFunc.
func (w *mockDNSResponseWriter) WriteMsg(msg *dns.Msg) error {
	write := w.writeFunc
	return write(msg)
}
// TestHTTPSMsgPackError checks that a request which cannot be packed makes
// ServeDNS fail with SERVFAIL and an error. Neither the DNS client nor the
// response writer may be reached.
func TestHTTPSMsgPackError(t *testing.T) {
	// An Rcode this large cannot be encoded, so Pack is expected to fail.
	// miekg/dns rejects rcodes above 0xFFF, the 12-bit EDNS-extended range.
	badReq := &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: 0xFFFFF}}

	neverQuery := func(context.Context, []byte) (*dns.Msg, error) {
		t.Fatal("dns client must not be called")
		return nil, nil
	}
	neverWrite := func(*dns.Msg) error {
		t.Fatal("dns response writer must not be called")
		return nil
	}

	handler := newHTTPS(".", mockDNSClientFunc(neverQuery))
	writer := &mockDNSResponseWriter{
		ResponseWriter: &test.ResponseWriter{},
		writeFunc:      neverWrite,
	}

	rcode, err := handler.ServeDNS(context.Background(), writer, badReq)
	require.Error(t, err)
	require.Equal(t, dns.RcodeServerFailure, rcode)
}
// TestHTTPSDNSClientError checks that an error from the DNS client is
// propagated and mapped to SERVFAIL, and that nothing is written to the
// response writer in that case.
func TestHTTPSDNSClientError(t *testing.T) {
	req := newRequestDNSMsg()

	// The client returns a non-nil message alongside the error. Presumably
	// this ensures the handler gives the error precedence over the message.
	failingClient := mockDNSClientFunc(func(context.Context, []byte) (*dns.Msg, error) {
		return newExpectedDNSMsg(), errors.New("dns client error")
	})
	handler := newHTTPS(".", failingClient)

	writer := &mockDNSResponseWriter{ResponseWriter: &test.ResponseWriter{}}
	writer.writeFunc = func(*dns.Msg) error {
		t.Fatal("dns response writer must not be called")
		return nil
	}

	rcode, err := handler.ServeDNS(context.Background(), writer, req)
	require.Error(t, err)
	require.Equal(t, dns.RcodeServerFailure, rcode)
}
// TestHTTPSResponseWriterError checks that a failure while writing the reply
// back to the client is reported as an error by ServeDNS. The upstream query
// itself succeeds. The returned rcode is not checked here.
func TestHTTPSResponseWriterError(t *testing.T) {
	req := newRequestDNSMsg()

	okClient := mockDNSClientFunc(func(context.Context, []byte) (*dns.Msg, error) {
		return newExpectedDNSMsg(), nil
	})
	failingWriter := &mockDNSResponseWriter{
		ResponseWriter: &test.ResponseWriter{},
		writeFunc: func(*dns.Msg) error {
			return errors.New("response writer error")
		},
	}

	handler := newHTTPS(".", okClient)
	_, err := handler.ServeDNS(context.Background(), failingWriter, req)
	require.Error(t, err)
}
// TestHTTPSDNSResponseStateNotMatch checks the case where the upstream answer
// carries a question that differs from the request. ServeDNS must return no
// error, and the reply recorded for the client must have rcode FORMERR.
func TestHTTPSDNSResponseStateNotMatch(t *testing.T) {
	req := newRequestDNSMsg()

	mismatched := mockDNSClientFunc(func(context.Context, []byte) (*dns.Msg, error) {
		resp := newExpectedDNSMsg()
		// Answer for a name other than the one that was asked about.
		resp.Question[0].Name = "other.domain."
		return resp, nil
	})

	handler := newHTTPS(".", mismatched)
	recorder := dnstest.NewRecorder(&test.ResponseWriter{})

	_, err := handler.ServeDNS(context.Background(), recorder, req)
	require.NoError(t, err)
	require.Equal(t, dns.RcodeFormatError, recorder.Rcode)
}