Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/gin-gonic/gin v1.11.0
github.com/miekg/dns v1.1.68
github.com/prometheus/client_golang v1.23.2
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.43.0
)

Expand All @@ -16,6 +17,7 @@ require (
github.com/bits-and-blooms/bitset v1.24.3 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/earthboundkid/versioninfo/v2 v2.24.1 // indirect
Expand All @@ -29,18 +31,21 @@ require (
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/oapi-codegen/runtime v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.2 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/redis/go-redis/v9 v9.16.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
go.mongodb.org/mongo-driver v1.17.6 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

require (
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKk
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
5 changes: 4 additions & 1 deletion internal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ func (app *App) RunServer() error {
return fmt.Errorf("failed to start HTTP server: %w", err)
}

dispatcher := NewDNSDispatcher(app.Upstream, blockList, CACHE_SIZE)
dispatcher, err := NewDNSDispatcher(app.Upstream, blockList, CACHE_SIZE)
if err != nil {
return fmt.Errorf("failed to create dispatcher: %w", err)
}

if app.DevMode {
dnsServer := &dns.Server{
Expand Down
19 changes: 17 additions & 2 deletions internal/dns.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"errors"
"fmt"
"log"
"math"
Expand All @@ -25,7 +26,7 @@ type DNSDispatcher struct {
uniqueClientsHLL *hyperloglog.Sketch
}

func NewDNSDispatcher(upstream string, blockList *BlockList, maxSize int) *DNSDispatcher {
func NewDNSDispatcher(upstream string, blockList *BlockList, maxSize int) (*DNSDispatcher, error) {

cache := cache.NewCache[string, *dns.Msg]().WithMaxKeys(maxSize).WithLRU()
sketch := hyperloglog.New14()
Expand Down Expand Up @@ -70,7 +71,9 @@ func NewDNSDispatcher(upstream string, blockList *BlockList, maxSize int) *DNSDi
return float64(sketch.Estimate())
})

prometheus.MustRegister(latencyHistogram, errorCounts, cacheStats, requestCounts, uniqueClientsCount)
if err := shouldRegister(latencyHistogram, errorCounts, cacheStats, requestCounts, uniqueClientsCount); err != nil {
return nil, fmt.Errorf("failed to register: %w", err)
}

return &DNSDispatcher{
dnsClient: &dnsClient,
Expand All @@ -82,7 +85,19 @@ func NewDNSDispatcher(upstream string, blockList *BlockList, maxSize int) *DNSDi
errorCounts: errorCounts,
requestCounts: requestCounts,
uniqueClientsHLL: sketch,
}, nil
}

func shouldRegister(cs ...prometheus.Collector) error {
var are prometheus.AlreadyRegisteredError
for _, coll := range cs {
if err := prometheus.Register(coll); err != nil {
if !errors.As(err, &are) {
return err
}
}
}
return nil
}

func (d *DNSDispatcher) HandleDNSRequest(writer dns.ResponseWriter, req *dns.Msg) {
Expand Down
161 changes: 161 additions & 0 deletions internal/dns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package internal

import (
"net"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// MockResponseWriter is a mock implementation of dns.ResponseWriter.
type MockResponseWriter struct {
mock.Mock
WrittenMsg *dns.Msg
}

func (m *MockResponseWriter) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP("192.0.2.10"),
Port: 8080,
}
}

func (m *MockResponseWriter) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP("192.0.2.10"),
Port: 8080,
}
}

func (m *MockResponseWriter) WriteMsg(msg *dns.Msg) error {
m.WrittenMsg = msg
args := m.Called(msg)
return args.Error(0)
}

func (m *MockResponseWriter) Write(b []byte) (int, error) {
return len(b), nil
}

func (m *MockResponseWriter) Close() error {
return nil
}

func (m *MockResponseWriter) TsigStatus() error {
return nil
}

func (m *MockResponseWriter) TsigTimersOnly(b bool) {
}

func (m *MockResponseWriter) Hijack() {
}

var blockList = NewBlockList([]string{"ads.0xbt.net"}, 0.0001)
var upstream = "8.8.8.8:53"
Comment on lines +56 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using global variables for test setup can introduce side effects between tests, especially if tests are run in parallel in the future (t.Parallel()). It's a best practice to define these variables within each test function or a test-specific setup function to ensure complete test isolation.


func TestDNSDispatcher_HandleDNSRequest_Allowed(t *testing.T) {
dispatcher, err := NewDNSDispatcher(upstream, blockList, 100)
assert.NoError(t, err)

req := new(dns.Msg)
req.SetQuestion("google.com.", dns.TypeA)

writer := new(MockResponseWriter)
writer.On("WriteMsg", mock.Anything).Return(nil)

// Call the method under test
dispatcher.HandleDNSRequest(writer, req)

// Assert that the response writer was called with a non-nil message
assert.NotNil(t, writer.WrittenMsg)
assert.Equal(t, dns.RcodeSuccess, writer.WrittenMsg.Rcode)
}

func TestDNSDispatcher_HandleDNSRequest_Blocked(t *testing.T) {
dispatcher, err := NewDNSDispatcher(upstream, blockList, 100)
assert.NoError(t, err)

req := new(dns.Msg)
req.SetQuestion("ads.0xbt.net.", dns.TypeA)

writer := new(MockResponseWriter)
writer.On("WriteMsg", mock.Anything).Return(nil)

// Call the method under test
dispatcher.HandleDNSRequest(writer, req)

// Assert that the response has an NXDOMAIN Rcode
assert.NotNil(t, writer.WrittenMsg)
assert.Equal(t, dns.RcodeNameError, writer.WrittenMsg.Rcode)
}

func TestDNSDispatcher_HandleDNSRequest_MultipleQuestions(t *testing.T) {
dispatcher, err := NewDNSDispatcher(upstream, blockList, 100)
assert.NoError(t, err)

req := new(dns.Msg)
req.Question = []dns.Question{
{Name: "google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
{Name: "ads.0xbt.net.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}

writer := new(MockResponseWriter)
writer.On("WriteMsg", mock.Anything).Return(nil)

// Call the method under test
dispatcher.HandleDNSRequest(writer, req)

// FIXME: this is the correct/expected behaviour .. to be fixed in #3
// ==================================================================
// Assert that the response writer was called with a non-nil message
// assert.NotNil(t, writer.WrittenMsg)
// assert.Equal(t, dns.RcodeSuccess, writer.WrittenMsg.Rcode)
// assert.Len(t, writer.WrittenMsg.Answer, 1)
// assert.Len(t, writer.WrittenMsg.Question, 2)

// FIXME: current behaviour (to be removed)
// ================================================
// Assert that the response has an NXDOMAIN Rcode because one of the questions is blocked
assert.NotNil(t, writer.WrittenMsg)
assert.Equal(t, dns.RcodeNameError, writer.WrittenMsg.Rcode)
assert.Len(t, writer.WrittenMsg.Answer, 0)
assert.Len(t, writer.WrittenMsg.Question, 1)
}

func TestDNSDispatcher_HandleDNSRequest_CacheHit(t *testing.T) {
dispatcher, err := NewDNSDispatcher(upstream, blockList, 100)
assert.NoError(t, err)

req := new(dns.Msg)
req.SetQuestion("example.com.", dns.TypeA)

writer := new(MockResponseWriter)
writer.On("WriteMsg", mock.Anything).Return(nil)

// First request: should be a cache miss and populate the cache
dispatcher.HandleDNSRequest(writer, req)
assert.NotNil(t, writer.WrittenMsg)
assert.Equal(t, dns.RcodeSuccess, writer.WrittenMsg.Rcode)

// Assert cache stats
stats := dispatcher.cache.Stat()
assert.Equal(t, 0, stats.Hits, "Expected 0 cache hit")
assert.Equal(t, 1, stats.Misses, "Expected 1 cache miss")

// Reset mock for the second request
writer = new(MockResponseWriter)
writer.On("WriteMsg", mock.Anything).Return(nil)

// Second request: should be a cache hit
dispatcher.HandleDNSRequest(writer, req)
assert.NotNil(t, writer.WrittenMsg)
assert.Equal(t, dns.RcodeSuccess, writer.WrittenMsg.Rcode)

// Assert cache stats
stats = dispatcher.cache.Stat()
assert.Equal(t, 1, stats.Hits, "Expected 1 cache hit")
assert.Equal(t, 1, stats.Misses, "Expected 1 cache miss")
}