diff --git a/.golangci.yml b/.golangci.yml
index 7021358f..92d59bbd 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -9,6 +9,10 @@ issues:
- dupl
- goerr113
- containedctx
+ - path: internal/dnssec/.*\.go
+ linters:
+ - gocognit
+ - cyclop
- path: main\.go
text: File is not `goimports`-ed
linters:
diff --git a/Dockerfile b/Dockerfile
index 92037650..6bafd4df 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -84,6 +84,7 @@ ENV \
MIDDLEWARE_LOG_RESPONSES=off \
MIDDLEWARE_LOCALDNS_ENABLED=on \
MIDDLEWARE_LOCALDNS_RESOLVERS= \
+ DNSSEC_VALIDATION=on \
CACHE_TYPE=lru \
CACHE_LRU_MAX_ENTRIES=10000 \
BLOCK_MALICIOUS=on \
@@ -116,5 +117,4 @@ LABEL \
COPY --from=build --chown=1000 /tmp/gobuild/entrypoint /entrypoint
# Downloads and install some files
-# TODO once DNSSEC is operational
# RUN /entrypoint build
diff --git a/README.md b/README.md
index 6d721ad6..1d90a8be 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-# DNS over TLS or HTTPs forwarding resolver
+# DNS over TLS or HTTPs forwarding security aware resolver
-Resolver communicating with public DNS recursive servers over encrypted channels with TLS or HTTPs.
-It also does **caching**, **filtering**, **split-horizon DNS**, **IPv6**, **Prometheus metrucs**.
+Security aware resolver communicating with public DNS recursive servers over encrypted channels with TLS or HTTPs.
+It also does **caching**, **filtering**, **split-horizon DNS**, **IPv6**, **DNSSEC** and **Prometheus metrucs**.
It's fully coded in Go and is a single and cross platform binary program.
-**Announcement**: *I am currently working on a DNSSEC validator implementation to reach feature parity with the v1.x.x image using Unbound*
+**Announcement**: *DNSSEC validation is now implemented, finally reaching feature parity with the v1.x.x image using Unbound*
**The `:v2.0.0-beta` Docker image breaks compatibility with previous images based on v1.x.x versions**
@@ -54,6 +54,7 @@ It's fully coded in Go and is a single and cross platform binary program.
- auto-update [block lists](https://github.com/qdm12/files) periodically with minimal downtime
- Specify custom hostnames and IP addresses
- DNS rebinding protection
+- [DNSSEC validation](https://github.com/qdm12/dns/blob/v2.0.0-beta/internal/dnssec/readme.md) ✅
- [Prometheus Metrics](https://github.com/qdm12/dns/blob/v2.0.0-beta/readme/metrics)
- Container specific features 🐋
- Tiny **16MB** Docker image (uncompressed, amd64) based on the empty image [scratch](https://hub.docker.com/_/scratch)
@@ -134,6 +135,7 @@ For example, the environment variable `UPSTREAM_TYPE` corresponds to the CLI fla
| `LISTENING_ADDRESS` | `:53` | DNS server listening address |
| `CACHE_TYPE` | `lru` | `lru` or `noop`. LRU caches DNS responses by least recently used |
| `CACHE_LRU_MAX_ENTRIES` | `10000` | Number of elements to keep in the LRU cache. |
+| `DNSSEC_VALIDATION` | `on` | `on` or `off`. Enable or disable DNSSEC validation |
| `METRICS_TYPE` | `noop` | `noop` or `prometheus` |
| `METRICS_PROMETHEUS_ADDRESS` | `:9090` | HTTP Prometheus server listening address |
| `METRICS_PROMETHEUS_SUBSYSTEM` | `dns` | Prometheus metrics prefix/subsystem |
diff --git a/cmd/dns/main.go b/cmd/dns/main.go
index 3ba386a3..2ccf47de 100644
--- a/cmd/dns/main.go
+++ b/cmd/dns/main.go
@@ -155,7 +155,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, //nolint:cycl
return fmt.Errorf("cache: %w", err)
}
- dnsLoop, err := dns.New(settings, dnsLogger, blockBuilder, cache, prometheusRegistry)
+ dnsLoop, err := dns.New(settings, dnsLogger, blockBuilder, cache, prometheusRegistry,
+ *settings.DNSSEC.Enabled)
if err != nil {
return fmt.Errorf("creating DNS loop: %w", err)
}
diff --git a/internal/config/dnssec.go b/internal/config/dnssec.go
new file mode 100644
index 00000000..33983771
--- /dev/null
+++ b/internal/config/dnssec.go
@@ -0,0 +1,39 @@
+package config
+
+import (
+ "github.com/qdm12/gosettings"
+ "github.com/qdm12/gosettings/reader"
+ "github.com/qdm12/gotree"
+)
+
+type DNSSEC struct {
+ Enabled *bool
+}
+
+func (d *DNSSEC) setDefaults() {
+ d.Enabled = gosettings.DefaultPointer(d.Enabled, true)
+}
+
+func (d DNSSEC) validate() (err error) {
+ return nil
+}
+
+func (d *DNSSEC) String() string {
+ return d.ToLinesNode().String()
+}
+
+func (d *DNSSEC) ToLinesNode() (node *gotree.Node) {
+ if !*d.Enabled {
+ return gotree.New("DNSSEC validation: disabled")
+ }
+ return gotree.New("DNSSEC validation: enabled")
+}
+
+func (d *DNSSEC) read(reader *reader.Reader) (err error) {
+ d.Enabled, err = reader.BoolPtr("DNSSEC_VALIDATION")
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/config/settings.go b/internal/config/settings.go
index 4926d05b..3c789d02 100644
--- a/internal/config/settings.go
+++ b/internal/config/settings.go
@@ -32,6 +32,7 @@ type Settings struct {
MiddlewareLog MiddlewareLog
Metrics Metrics
LocalDNS LocalDNS
+ DNSSEC DNSSEC
CheckDNS *bool
UpdatePeriod *time.Duration
}
@@ -47,6 +48,7 @@ func (s *Settings) SetDefaults() {
s.MiddlewareLog.setDefaults()
s.Metrics.setDefaults()
s.LocalDNS.setDefault()
+ s.DNSSEC.setDefaults()
s.CheckDNS = gosettings.DefaultPointer(s.CheckDNS, true)
const defaultUpdaterPeriod = 24 * time.Hour
s.UpdatePeriod = gosettings.DefaultPointer(s.UpdatePeriod, defaultUpdaterPeriod)
@@ -77,6 +79,7 @@ func (s *Settings) Validate() (err error) {
"middleware log": s.MiddlewareLog.validate,
"metrics": s.Metrics.validate,
"local DNS": s.LocalDNS.validate,
+ "DNSSEC": s.DNSSEC.validate,
}
for name, validate := range nameToValidate {
err = validate()
@@ -119,6 +122,7 @@ func (s *Settings) ToLinesNode() (node *gotree.Node) {
node.AppendNode(s.MiddlewareLog.ToLinesNode())
node.AppendNode(s.Metrics.ToLinesNode())
node.AppendNode(s.LocalDNS.ToLinesNode())
+ node.AppendNode(s.DNSSEC.ToLinesNode())
node.Appendf("Check DNS: %s", gosettings.BoolToYesNo(s.CheckDNS))
if *s.UpdatePeriod == 0 {
@@ -130,6 +134,7 @@ func (s *Settings) ToLinesNode() (node *gotree.Node) {
return node
}
+//nolint:cyclop
func (s *Settings) Read(reader *reader.Reader, warner Warner) (err error) {
warnings := checkOutdatedEnv(reader)
for _, warning := range warnings {
@@ -173,6 +178,11 @@ func (s *Settings) Read(reader *reader.Reader, warner Warner) (err error) {
return fmt.Errorf("local DNS settings: %w", err)
}
+ err = s.DNSSEC.read(reader)
+ if err != nil {
+ return fmt.Errorf("DNSSEC settings: %w", err)
+ }
+
s.CheckDNS, err = reader.BoolPtr("CHECK_DNS")
if err != nil {
return err
diff --git a/internal/dns/loop.go b/internal/dns/loop.go
index 1d131d6f..997e38b2 100644
--- a/internal/dns/loop.go
+++ b/internal/dns/loop.go
@@ -21,6 +21,7 @@ type Loop struct {
blockBuilder BlockBuilder
cache Cache
prometheusRegistry PrometheusRegistry
+ dnssecEnabled bool
dnsServer Service
updateTimer *time.Timer
@@ -31,7 +32,8 @@ type Loop struct {
func New(settings config.Settings, logger Logger,
blockBuilder BlockBuilder, cache Cache,
- prometheusRegistry PrometheusRegistry) (loop *Loop, err error) {
+ prometheusRegistry PrometheusRegistry, dnssecEnabled bool) (
+ loop *Loop, err error) {
settings.SetDefaults()
err = settings.Validate()
if err != nil {
@@ -44,6 +46,7 @@ func New(settings config.Settings, logger Logger,
blockBuilder: blockBuilder,
cache: cache,
prometheusRegistry: prometheusRegistry,
+ dnssecEnabled: dnssecEnabled,
}, nil
}
@@ -209,7 +212,7 @@ func (l *Loop) setupAll(ctx context.Context, downloadBlockFiles bool) ( //nolint
}
server, err := setup.DNS(l.settings, l.ipv6Support, l.cache,
- filter, l.logger, l.prometheusRegistry)
+ filter, l.logger, l.prometheusRegistry, l.dnssecEnabled)
if err != nil {
return nil, err
}
diff --git a/internal/dnssec/chain.go b/internal/dnssec/chain.go
new file mode 100644
index 00000000..91554191
--- /dev/null
+++ b/internal/dnssec/chain.go
@@ -0,0 +1,149 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// buildDelegationChain queries the RRs required for the zone validation.
+// It begins the queries at the root zone and then go down the delegation
+// chain until it reaches the desired zone, or an unsigned zone.
+// It returns a delegation chain of signed zones where the
+// first signed zone (index 0) is the root zone and the last signed
+// zone is the last signed zone, which can be the desired zone.
+func buildDelegationChain(handler dns.Handler, desiredZone string, qClass uint16) (
+ delegationChain []signedData, err error) {
+ zoneNames := desiredZoneToZoneNames(desiredZone)
+ delegationChain = make([]signedData, 0, len(zoneNames))
+
+ for _, zoneName := range zoneNames {
+ // zoneName iterates in this order: ., com., example.com.
+ data, signed, err := queryDelegation(handler, zoneName, qClass)
+ if err != nil {
+ return nil, fmt.Errorf("querying delegation for desired zone %s: %w",
+ desiredZone, err)
+ }
+ delegationChain = append(delegationChain, data)
+ if !signed {
+ // first zone without a DS RRSet, but it should
+ // have at least one NSEC or NSEC3 RRSet, even for
+ // NXDOMAIN responses.
+ break
+ }
+ }
+
+ return delegationChain, nil
+}
+
+func desiredZoneToZoneNames(desiredZone string) (zoneNames []string) {
+ if desiredZone == "." {
+ return []string{"."}
+ }
+
+ zoneParts := strings.Split(desiredZone, ".")
+ zoneNames = make([]string, len(zoneParts))
+ for i := range zoneParts {
+ zoneNames[i] = dns.Fqdn(strings.Join(zoneParts[len(zoneParts)-1-i:], "."))
+ }
+ return zoneNames
+}
+
+// queryDelegation obtains the DS RRSet and the DNSKEY RRSet
+// for a given zone and class, and creates a signed zone with
+// this information. It does not query the (non existent)
+// DS record for the root zone, which is the trust root anchor.
+func queryDelegation(handler dns.Handler, zone string, qClass uint16) (
+ data signedData, signed bool, err error) {
+ data.zone = zone
+ data.class = qClass
+
+ // TODO set root zone DS here!
+
+ // do not query DS for root zone since its DS record
+ // is the trust root anchor.
+ if zone != "." {
+ data.dsResponse, err = queryDS(handler, zone, qClass)
+ if err != nil {
+ return signedData{}, false, fmt.Errorf("querying DS record: %w", err)
+ }
+
+ if data.dsResponse.isNoData() || data.dsResponse.isNXDomain() {
+ // If no DS RRSet is found, the entire zone is unsigned.
+ // This also means no DNSKEY RRSet exists, since child zones are
+ // also unsigned, so return with the error errZoneHasNoDSRcord
+ // to signal the caller to stop the delegation chain queries for
+ // child zones when encountering a zone with no DS RRSet.
+ return data, false, nil
+ }
+ }
+
+ data.dnsKeyResponse, err = queryDNSKeys(handler, zone, qClass)
+ if err != nil {
+ return signedData{}, true, fmt.Errorf("querying DNSKEY record: %w", err)
+ }
+
+ return data, true, nil
+}
+
+var (
+ ErrDSAndNSECAbsent = errors.New("zone has no DS record and no NSEC record")
+)
+
+func queryDS(handler dns.Handler, zone string, qClass uint16) (
+ response dnssecResponse, err error) {
+ response, err = queryRRSets(handler, zone, qClass, dns.TypeDS)
+ switch {
+ case err != nil:
+ return dnssecResponse{}, err
+ case !response.isSigned():
+ // no signed DS answer and no NSEC/NSEC3 authority RR
+ return dnssecResponse{}, wrapError(
+ zone, qClass, dns.TypeDS, ErrDSAndNSECAbsent)
+ case response.isNXDomain(), response.isNoData():
+ // there is one or more NSEC/NSEC3 authority RRSets.
+ return response, nil
+ }
+ // signed answer RRSet(s)
+
+ // Double check we only have 1 DS RRSet.
+ // TODO remove?
+ err = dnssecRRSetsIsSingleOfType(response.answerRRSets, dns.TypeDS)
+ if err != nil {
+ return dnssecResponse{},
+ wrapError(zone, qClass, dns.TypeDS, err)
+ }
+
+ return response, nil
+}
+
+// queryDNSKeys queries the DNSKEY records for a given signed zone
+// containing a DS RRSet. It returns an error if the DNSKEY RRSet is
+// missing or is unsigned.
+// Note this returns all the DNSKey RRs, even non-zone ones.
+func queryDNSKeys(handler dns.Handler, qname string, qClass uint16) (
+ response dnssecResponse, err error) {
+ // DNSKey RRSet(s) should be present so the NSEC/NSEC3 RRSet is ignored.
+ response, err = queryRRSets(handler, qname, qClass, dns.TypeDNSKEY)
+ switch {
+ case err != nil:
+ return dnssecResponse{}, err
+ case !response.isSigned(), response.isNoData(): // cannot be NXDOMAIN
+ // no signed DNSKEY answer
+ return dnssecResponse{}, fmt.Errorf("for %s: %w",
+ nameClassTypeToString(qname, qClass, dns.TypeDNSKEY),
+ ErrDNSKeyNotFound)
+ }
+
+ // Double check we only have 1 DNSKEY RRSet.
+ // TODO remove?
+ err = dnssecRRSetsIsSingleOfType(response.answerRRSets, dns.TypeDNSKEY)
+ if err != nil {
+ return dnssecResponse{},
+ wrapError(qname, qClass, dns.TypeDNSKEY, err)
+ }
+
+ return response, nil
+}
diff --git a/internal/dnssec/chain_test.go b/internal/dnssec/chain_test.go
new file mode 100644
index 00000000..fd86ed50
--- /dev/null
+++ b/internal/dnssec/chain_test.go
@@ -0,0 +1,39 @@
+package dnssec
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_desiredZoneToZoneNames(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ desiredZone string
+ zoneNames []string
+ }{
+ "root": {
+ desiredZone: ".",
+ zoneNames: []string{"."},
+ },
+ "com": {
+ desiredZone: "com.",
+ zoneNames: []string{".", "com."},
+ },
+ "example.com": {
+ desiredZone: "example.com.",
+ zoneNames: []string{".", "com.", "example.com."},
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ zoneNames := desiredZoneToZoneNames(testCase.desiredZone)
+ assert.Equal(t, testCase.zoneNames, zoneNames)
+ })
+ }
+}
diff --git a/internal/dnssec/cname.go b/internal/dnssec/cname.go
new file mode 100644
index 00000000..0224a94a
--- /dev/null
+++ b/internal/dnssec/cname.go
@@ -0,0 +1,25 @@
+package dnssec
+
+import (
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+func mustRRToCNAME(rr dns.RR) *dns.CNAME {
+ cname, ok := rr.(*dns.CNAME)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.CNAME", rr))
+ }
+ return cname
+}
+
+func getCnameTarget(rrSets []dnssecRRSet) (target string) {
+ for _, rrSet := range rrSets {
+ if rrSet.qtype() == dns.TypeCNAME {
+ cname := mustRRToCNAME(rrSet.rrSet[0])
+ return cname.Target
+ }
+ }
+ return ""
+}
diff --git a/internal/dnssec/debug.go b/internal/dnssec/debug.go
new file mode 100644
index 00000000..46bf59e7
--- /dev/null
+++ b/internal/dnssec/debug.go
@@ -0,0 +1,7 @@
+package dnssec
+
+import "github.com/qdm12/log"
+
+//nolint:gochecknoglobals
+var globalDebugLogger = log.New(log.SetCallerFile(true),
+ log.SetCallerLine(true), log.SetComponent("dnssec-debug"), log.SetLevel(log.LevelDebug))
diff --git a/internal/dnssec/dnskey.go b/internal/dnssec/dnskey.go
new file mode 100644
index 00000000..b9e46b89
--- /dev/null
+++ b/internal/dnssec/dnskey.go
@@ -0,0 +1,71 @@
+package dnssec
+
+import (
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+func mustRRToDNSKey(rr dns.RR) *dns.DNSKEY {
+ dnsKey, ok := rr.(*dns.DNSKEY)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.DNSKEY", rr))
+ }
+ return dnsKey
+}
+
+// makeKeyTagToDNSKey creates a map of key tag to DNSKEY from a DNSKEY RRSet,
+// ignoring any RR which is not a Zone signing key.
+func makeKeyTagToDNSKey(dnsKeyRRSet []dns.RR) (keyTagToDNSKey map[uint16]*dns.DNSKEY) {
+ keyTagToDNSKey = make(map[uint16]*dns.DNSKEY, len(dnsKeyRRSet))
+ for _, dnsKeyRR := range dnsKeyRRSet {
+ dnsKey := mustRRToDNSKey(dnsKeyRR)
+ if dnsKey.Flags&dns.ZONE == 0 {
+ // As described in https://datatracker.ietf.org/doc/html/rfc4034#section-2.1.1
+ // and https://datatracker.ietf.org/doc/html/rfc4034#section-5.2:
+ // If bit 7 has value 0, then the DNSKEY record holds some other type of DNS
+ // public key and MUST NOT be used to verify RRSIGs that cover RRsets.
+ // The DNSKEY RR Flags MUST have Flags bit 7 set. If the
+ // DNSKEY flags do not indicate a DNSSEC zone key, the DS
+ // RR (and the DNSKEY RR it references) MUST NOT be used
+ // in the validation process.
+ continue
+ }
+ keyTagToDNSKey[dnsKey.KeyTag()] = dnsKey
+ }
+ return keyTagToDNSKey
+}
+
+const (
+ algoPreferenceRecommended uint8 = iota
+ algoPreferenceMust
+ algoPreferenceMay
+ algoPreferenceMustNot
+ algoPreferenceUnknown
+)
+
+// lessDNSKeyAlgorithm returns true if algoID1 < algoID2 in terms
+// of preference. The preference is determined by the table defined in:
+// https://datatracker.ietf.org/doc/html/rfc8624#section-3.1
+func lessDNSKeyAlgorithm(algoID1, algoID2 uint8) bool {
+ return algoIDToPreference(algoID1) < algoIDToPreference(algoID2)
+}
+
+// algoIDToPreference returns the preference level of the algorithm ID.
+// Note this is a function with a switch statement, which not only provide
+// immutability compared to a global variable map, but is also x10 faster
+// than map lookups.
+func algoIDToPreference(algoID uint8) (preference uint8) {
+ switch algoID {
+ case dns.RSAMD5, dns.DSA, dns.DSANSEC3SHA1:
+ return algoPreferenceMustNot
+ case dns.ECCGOST:
+ return algoPreferenceMay
+ case dns.RSASHA1, dns.RSASHA1NSEC3SHA1, dns.RSASHA256, dns.RSASHA512, dns.ECDSAP256SHA256:
+ return algoPreferenceMust
+ case dns.ECDSAP384SHA384, dns.ED25519, dns.ED448:
+ return algoPreferenceRecommended
+ default:
+ return algoPreferenceUnknown
+ }
+}
diff --git a/internal/dnssec/dnskey_test.go b/internal/dnssec/dnskey_test.go
new file mode 100644
index 00000000..616d6e18
--- /dev/null
+++ b/internal/dnssec/dnskey_test.go
@@ -0,0 +1,57 @@
+package dnssec
+
+import "testing"
+
+var testGlobalMap = map[uint8]uint8{ //nolint:gochecknoglobals
+ 1: 1,
+ 2: 2,
+ 3: 3,
+ 4: 4,
+ 5: 5,
+ 6: 6,
+ 7: 7,
+ 8: 8,
+}
+
+func testSwitchStatement(key uint8) uint8 {
+ switch key {
+ case 1:
+ return 1
+ case 2:
+ return 2
+ case 3:
+ return 3
+ case 4:
+ return 4
+ case 5:
+ return 5
+ case 6:
+ return 6
+ case 7:
+ return 7
+ case 8:
+ return 8
+ default:
+ panic("invalid key")
+ }
+}
+
+// This benchmark aims to check if, for algoIDToPreference, it is
+// better to:
+// 1. have a global map variable
+// 2. have a function with a switch statement
+// The second point at equal performance is better due to its
+// immutability nature, unlike 1.
+func Benchmark_globalMap_switch(b *testing.B) {
+ b.Run("global_map", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = testGlobalMap[1]
+ }
+ })
+
+ b.Run("switch", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _ = testSwitchStatement(1)
+ }
+ })
+}
diff --git a/internal/dnssec/dnssec.go b/internal/dnssec/dnssec.go
new file mode 100644
index 00000000..55c7af99
--- /dev/null
+++ b/internal/dnssec/dnssec.go
@@ -0,0 +1,63 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/miekg/dns"
+ "github.com/qdm12/dns/v2/internal/local"
+ "github.com/qdm12/dns/v2/internal/stateful"
+)
+
+var (
+ ErrQuestionsMultiple = errors.New("multiple questions")
+)
+
+func Validate(request *dns.Msg, handler dns.Handler) (response *dns.Msg, err error) {
+ switch len(request.Question) {
+ case 0:
+ response = new(dns.Msg)
+ response.SetRcode(request, dns.RcodeSuccess)
+ return response, nil
+ case 1:
+ default:
+ return nil, fmt.Errorf("%w: %d", ErrQuestionsMultiple, len(request.Question))
+ }
+
+ desiredZone := request.Question[0].Name
+ qType := request.Question[0].Qtype
+ qClass := request.Question[0].Qclass
+
+ if local.IsFQDNLocal(desiredZone) {
+ // Do not perform DNSSEC validation for local zones
+ writer := stateful.NewWriter()
+ handler.ServeDNS(writer, request)
+ return writer.Response, nil
+ }
+
+ desiredResponse, err := queryRRSets(handler, desiredZone, qClass, qType)
+ if err != nil {
+ return nil, fmt.Errorf("running desired query: %w", err)
+ }
+
+ originalDesiredZone := desiredZone
+ cnameTarget := getCnameTarget(desiredResponse.answerRRSets)
+ if cnameTarget != "" {
+ desiredZone = cnameTarget
+ }
+
+ delegationChain, err := buildDelegationChain(handler, desiredZone, qClass)
+ if err != nil {
+ return nil, fmt.Errorf("building delegation chain for %s: %w",
+ originalDesiredZone, err)
+ }
+
+ err = validateWithChain(desiredZone, qType, desiredResponse, delegationChain)
+ if err != nil {
+ return nil, fmt.Errorf("for %s: validating answer RRSets"+
+ " with delegation chain: %w",
+ nameClassTypeToString(originalDesiredZone, qClass, qType), err)
+ }
+
+ return desiredResponse.toDNSMsg(request), nil
+}
diff --git a/internal/dnssec/ds.go b/internal/dnssec/ds.go
new file mode 100644
index 00000000..c048d5e2
--- /dev/null
+++ b/internal/dnssec/ds.go
@@ -0,0 +1,15 @@
+package dnssec
+
+import (
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+func mustRRToDS(rr dns.RR) *dns.DS {
+ ds, ok := rr.(*dns.DS)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.DS", rr))
+ }
+ return ds
+}
diff --git a/internal/dnssec/edns0.go b/internal/dnssec/edns0.go
new file mode 100644
index 00000000..c552adeb
--- /dev/null
+++ b/internal/dnssec/edns0.go
@@ -0,0 +1,28 @@
+package dnssec
+
+import "github.com/miekg/dns"
+
+func newEDNSRequest(zone string, qClass, qType uint16) (request *dns.Msg) {
+ request = new(dns.Msg).SetQuestion(zone, qType)
+ request.Question[0].Qclass = qClass
+ request.RecursionDesired = true
+ const maxUDPSize = 4096
+ const doEdns0 = true
+ request.SetEdns0(maxUDPSize, doEdns0)
+ return request
+}
+
+func isRequestAskingForDNSSEC(request *dns.Msg) bool {
+ opt := request.IsEdns0()
+ if opt == nil {
+ return false
+ }
+
+ // See https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.3
+ const minUDPSize = 512
+
+ return opt.Hdr.Name == "." &&
+ opt.Hdr.Rrtype == dns.TypeOPT &&
+ opt.Hdr.Class >= minUDPSize && // UDP size
+ opt.Do()
+}
diff --git a/internal/dnssec/errors.go b/internal/dnssec/errors.go
new file mode 100644
index 00000000..5d0235f1
--- /dev/null
+++ b/internal/dnssec/errors.go
@@ -0,0 +1,26 @@
+package dnssec
+
+import "errors"
+
+var (
+ // TODO review exported errors usage and all sentinel errors.
+ ErrBogus = errors.New("bogus response")
+)
+
+var _ error = (*joinedErrors)(nil)
+
+type joinedErrors struct { //nolint:errname
+ errs []error
+}
+
+func (e *joinedErrors) add(err error) {
+ e.errs = append(e.errs, err)
+}
+
+func (e *joinedErrors) Error() string {
+ return joinStrings(e.errs, "and")
+}
+
+func (e *joinedErrors) Unwrap() []error {
+ return e.errs
+}
diff --git a/internal/dnssec/helpers.go b/internal/dnssec/helpers.go
new file mode 100644
index 00000000..4e9634a1
--- /dev/null
+++ b/internal/dnssec/helpers.go
@@ -0,0 +1,54 @@
+package dnssec
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+ "golang.org/x/exp/constraints"
+)
+
+func nameClassTypeToString(qname string, qClass, qType uint16) string {
+ return qname + " " + dns.ClassToString[qClass] + " " + dns.TypeToString[qType]
+}
+
+func nameTypeToString(qname string, qType uint16) string {
+ return qname + " " + dns.TypeToString[qType]
+}
+
+func hashToString(hashType uint8) string {
+ s, ok := dns.HashToString[hashType]
+ if ok {
+ return s
+ }
+ return fmt.Sprintf("%d", hashType)
+}
+
+func hashesToString(hashTypes []uint8) string {
+ hashStrings := make([]string, len(hashTypes))
+ for i, hash := range hashTypes {
+ hashStrings[i] = hashToString(hash)
+ }
+ return strings.Join(hashStrings, ", ")
+}
+
+func integersToString[T constraints.Integer](integers []T) string {
+ integerStrings := make([]string, len(integers))
+ for i, hash := range integers {
+ integerStrings[i] = fmt.Sprint(hash)
+ }
+ return strings.Join(integerStrings, ", ")
+}
+
+func wrapError(zone string, qClass, qType uint16, err error) error {
+ return fmt.Errorf("for %s: %w", nameClassTypeToString(zone, qClass, qType), err)
+}
+
+func isOneOf[T comparable](value T, values ...T) bool {
+ for _, v := range values {
+ if value == v {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/dnssec/integration_test.go b/internal/dnssec/integration_test.go
new file mode 100644
index 00000000..b39009b8
--- /dev/null
+++ b/internal/dnssec/integration_test.go
@@ -0,0 +1,258 @@
+//go:build integration
+// +build integration
+
+package dnssec
+
+import (
+ "context"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+ "github.com/qdm12/dns/v2/internal/stateful"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func Test_Validate(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ request *dns.Msg
+ errWrapped error
+ errMessage string
+ }{
+ // "exists_not_signed": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "test.github.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "exists_signed": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "icann.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "nodata_nsec3": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "icann.org.", Qtype: dns.TypeMD, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "nxdomain_nsec3": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "xyz.icann.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "nxdomain_nsec": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "x.cloudflare.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "a_and_cname": {
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "sigok.ippacket.stream.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // //
+ // // Special cases
+ // //
+ // "dnssec_failed_by_upstream": {
+ // // One can also try rhybar.cz.
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "dnssec-failed.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // errWrapped: ErrRcodeBad,
+ // errMessage: "running desired query: " +
+ // "for dnssec-failed.org. IN A: " +
+ // "bad response rcode: SERVFAIL",
+ // },
+ // "signed_answer_insecure_parent": {
+ // // The answer is a NODATA with an NSEC RRSet signed by whispersystems.org.
+ // // The parent zone whispersystems.org. has DNSKEYs (ZSK+KSK) but
+ // // no DS record, so it is therefore insecure and so is the answer.
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "textsecure-service.whispersystems.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "nxdomain_2_rrsigs_per_nsec": {
+ // // There are two RRSIGs per NSEC RR, each with a
+ // // different algorithm. This is to allow transitioning
+ // // from one weaker/older algorithm to a stronger/newer one.
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "xyzzy14.sdsmt.edu.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ // "nodata_2_rrsigs_dnskey": {
+ // // The DNSKEY RRSet of vip.icann.org. is signed by two RRSIGs,
+ // // one validating against the ZSK of icann.org. and the other
+ // // validating against the KSK of icann.org. This is valid although
+ // // not very conventional.
+ // request: &dns.Msg{
+ // Question: []dns.Question{
+ // {Name: "vip.icann.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ // },
+ // },
+ // },
+ "wildcard_expanded": {
+ request: &dns.Msg{
+ Question: []dns.Question{
+ {Name: "b.zahrarestaurant.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
+ },
+ },
+ },
+ }
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ testCase.request.RecursionDesired = true
+ testCase.request.Id = dns.Id()
+ requestCopy := testCase.request.Copy()
+
+ handler := newIntegTestHandler(t)
+
+ response, err := Validate(testCase.request, handler)
+
+ require.ErrorIs(t, err, testCase.errWrapped)
+
+ var expectedResponse *dns.Msg
+ if testCase.errWrapped != nil {
+ assert.EqualError(t, err, testCase.errMessage)
+ } else { // no error, fetch expected response
+ statefulWriter := stateful.NewWriter()
+ requestCopy.Id = dns.Id()
+ handler.ServeDNS(statefulWriter, requestCopy)
+ expectedResponse = statefulWriter.Response
+ // DNSSEC does not do recursion for now
+ expectedResponse.RecursionAvailable = false
+ }
+
+ assertResponsesEqual(t, expectedResponse, response)
+ })
+ }
+}
+
+type integTestHandler struct {
+ t *testing.T
+ client *dns.Client
+ dialer *net.Dialer
+}
+
+func newIntegTestHandler(t *testing.T) *integTestHandler {
+ return &integTestHandler{
+ t: t,
+ client: &dns.Client{},
+ dialer: &net.Dialer{},
+ }
+}
+
+func (h *integTestHandler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
+ request = request.Copy()
+
+ deadline, ok := h.t.Deadline()
+ if !ok {
+ deadline = time.Now().Add(4 * time.Second)
+ }
+ ctx, cancel := context.WithDeadline(context.Background(), deadline)
+ defer cancel()
+
+ const maxTries = 3
+ success := false
+ var response *dns.Msg
+ for i := 0; i < maxTries; i++ {
+ const timeout = time.Second
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ // Try a new UDP connection on each try
+ netConn, err := h.dialer.DialContext(ctx, "udp", "1.1.1.1:53")
+ require.NoError(h.t, err)
+ dnsConn := &dns.Conn{Conn: netConn}
+
+ response, _, err = h.client.ExchangeWithConnContext(ctx, request, dnsConn)
+ if err != nil {
+ _ = dnsConn.Close()
+ h.t.Logf("try %d of %d: %s", i+1, maxTries, err)
+ continue
+ }
+
+ err = dnsConn.Close()
+ require.NoError(h.t, err)
+
+ success = true
+ break
+ }
+
+ if !success {
+ h.t.Fatalf("could not communicate with DNS server after %d tries", maxTries)
+ }
+
+ if !response.Truncated {
+ // Remove TTL fields from rrset
+ for i := range response.Answer {
+ response.Answer[i].Header().Ttl = 0
+ }
+
+ _ = w.WriteMsg(response)
+ return
+ }
+
+ // Retry with TCP
+ netConn, err := h.dialer.DialContext(ctx, "tcp", "1.1.1.1:53")
+ require.NoError(h.t, err)
+
+ dnsConn := &dns.Conn{Conn: netConn}
+ response, _, err = h.client.ExchangeWithConnContext(ctx, request, dnsConn)
+ require.NoError(h.t, err)
+
+ err = dnsConn.Close()
+ require.NoError(h.t, err)
+
+ _ = w.WriteMsg(response)
+}
+
+func assertResponsesEqual(t *testing.T, a, b *dns.Msg) {
+ if a == nil {
+ require.Nil(t, b)
+ return
+ }
+ require.NotNil(t, b)
+
+ // Remove TTL fields from answer and authority
+ for i := range a.Answer {
+ a.Answer[i].Header().Ttl = 0
+ }
+ for i := range a.Ns {
+ a.Ns[i].Header().Ttl = 0
+ }
+ for i := range b.Answer {
+ b.Answer[i].Header().Ttl = 0
+ }
+ for i := range b.Ns {
+ b.Ns[i].Header().Ttl = 0
+ }
+
+ a.Id = 0
+ b.Id = 0
+
+ assert.Equal(t, a, b)
+}
diff --git a/internal/dnssec/nodata.go b/internal/dnssec/nodata.go
new file mode 100644
index 00000000..80952ea0
--- /dev/null
+++ b/internal/dnssec/nodata.go
@@ -0,0 +1,90 @@
+package dnssec
+
+import (
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+// Note: validateNoData works also for the qtype DS since
+// the implementations of nsec3ValidateNoData and
+// nsecValidateNoData take care of redirecting to the
+// DS specific validation functions, but preferably use
+// validateNoDataDS for the qtype DS.
+func validateNoData(qname string, qtype uint16,
+ authoritySection []dnssecRRSet,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ err = verifyRRSetsRRSig(authoritySection, keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying RRSIGs: %w", err)
+ }
+
+ nsec3RRs, wildcard := extractNSEC3s(authoritySection)
+ if len(nsec3RRs) > 0 {
+ nsec3RRs, err = nsec3InitialChecks(nsec3RRs)
+ if err != nil {
+ return fmt.Errorf("initial NSEC3 checks: %w", err)
+ } else if wildcard {
+ return nsec3ValidateNoDataWildcard(qname, qtype, nsec3RRs)
+ }
+ return nsec3ValidateNoData(qname, qtype, nsec3RRs)
+ }
+
+ nsecRRs := extractNSECs(authoritySection)
+ if len(nsecRRs) > 0 {
+ return nsecValidateNoData(qname, qtype, nsecRRs)
+ }
+
+ return fmt.Errorf("verifying no data for %s: %w: "+
+ "no NSEC or NSEC3 record found",
+ nameTypeToString(qname, qtype), ErrBogus)
+}
+
+func validateNoDataDS(qname string,
+ authoritySection []dnssecRRSet,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ err = verifyRRSetsRRSig(authoritySection, keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying RRSIGs: %w", err)
+ }
+
+ nsec3RRs, wildcard := extractNSEC3s(authoritySection)
+ if len(nsec3RRs) > 0 {
+ nsec3RRs, err = nsec3InitialChecks(nsec3RRs)
+ if err != nil {
+ return fmt.Errorf("initial NSEC3 checks: %w", err)
+ } else if wildcard {
+ return nsec3ValidateNoDataWildcard(qname, dns.TypeDS, nsec3RRs)
+ }
+ return nsec3ValidateNoDataDS(qname, nsec3RRs)
+ }
+
+ nsecRRs := extractNSECs(authoritySection)
+ if len(nsecRRs) > 0 {
+ return nsecValidateNoDataDS(qname, nsecRRs)
+ }
+
+ return fmt.Errorf("verifying no DS data for %s: %w: "+
+ "no NSEC or NSEC3 record found",
+ qname, ErrBogus)
+}
+
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.6
+func verifyNoDataNsecxTypesDS(nsecVariant string,
+ nsecTypes []uint16) (err error) {
+ for _, nsecType := range nsecTypes {
+ switch nsecType {
+ case dns.TypeSOA:
+ return fmt.Errorf("%w: %s contains SOA type"+
+ " so is from the child zone and not the parent zone",
+ ErrBogus, nsecVariant)
+ case dns.TypeDS:
+ return fmt.Errorf("%w: %s contains DS type", ErrBogus, nsecVariant)
+ case dns.TypeCNAME:
+ return fmt.Errorf("%w: %s contains CNAME type",
+ ErrBogus, nsecVariant)
+ }
+ }
+
+ return nil
+}
diff --git a/internal/dnssec/nsec.go b/internal/dnssec/nsec.go
new file mode 100644
index 00000000..e0fb1b52
--- /dev/null
+++ b/internal/dnssec/nsec.go
@@ -0,0 +1,158 @@
+package dnssec
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+func mustRRToNSEC(rr dns.RR) (nsec *dns.NSEC) {
+ nsec, ok := rr.(*dns.NSEC)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.NSEC", rr))
+ }
+ return nsec
+}
+
+// extractNSECs returns the NSEC RRs found in the NSEC
+// signed RRSet from the slice of signed RRSets.
+func extractNSECs(rrSets []dnssecRRSet) (nsecs []dns.RR) {
+ for _, rrSet := range rrSets {
+ if rrSet.qtype() == dns.TypeNSEC {
+ return rrSet.rrSet
+ }
+ }
+ return nil
+}
+
+func nsecValidateNxDomain(qname string, nsecRRSet []dns.RR) (err error) {
+ for _, nsecRR := range nsecRRSet {
+ nsec := mustRRToNSEC(nsecRR)
+ if nsecCoversZone(qname, nsec.Hdr.Name, nsec.NextDomain) {
+ return nil
+ }
+ }
+
+ return fmt.Errorf("for qname %s: %w: "+
+ "no NSEC covering qname found",
+ qname, ErrBogus)
+}
+
+func nsecValidateNoData(qname string, qType uint16,
+ nsecRRSet []dns.RR) (err error) {
+ if qType == dns.TypeDS {
+ return nsecValidateNoDataDS(qname, nsecRRSet)
+ }
+
+ var qnameMatchingNSEC *dns.NSEC
+ for _, nsecRR := range nsecRRSet {
+ nsec := mustRRToNSEC(nsecRR)
+ if nsecMatchesQname(nsec, qname) {
+ qnameMatchingNSEC = nsec
+ break
+ }
+ }
+
+ if qnameMatchingNSEC == nil {
+ return fmt.Errorf("for zone %s and type %s: %w: "+
+ "no NSEC matching qname found",
+ qname, dns.TypeToString[qType], ErrBogus)
+ }
+
+ for _, nsecType := range qnameMatchingNSEC.TypeBitMap {
+ switch nsecType {
+ case qType:
+ return fmt.Errorf("for qname %s and type %s: %w: "+
+ "qtype contained in NSEC",
+ qname, dns.TypeToString[qType], ErrBogus)
+ case dns.TypeCNAME: // TODO check this is invalid
+ return fmt.Errorf("for qname %s and type %s: %w: "+
+ "CNAME contained in NSEC",
+ qname, dns.TypeToString[qType], ErrBogus)
+ }
+ }
+
+ return nil
+}
+
+func nsecValidateNoDataDS(qname string, nsecRRSet []dns.RR) (err error) {
+ var qnameMatchingNSEC *dns.NSEC
+ for _, nsecRR := range nsecRRSet {
+ nsec := mustRRToNSEC(nsecRR)
+ if nsecMatchesQname(nsec, qname) {
+ qnameMatchingNSEC = nsec
+ break
+ }
+ }
+
+ if qnameMatchingNSEC == nil {
+ return fmt.Errorf("for qname %s: %w: "+
+ "no NSEC matching qname found",
+ qname, ErrBogus)
+ }
+
+ err = verifyNoDataNsecxTypesDS("NSEC", qnameMatchingNSEC.TypeBitMap)
+ if err != nil {
+ return fmt.Errorf("for qname %s: %w",
+ qname, err)
+ }
+ return nil
+}
+
+// nsecMatchesQname returns true if the NSEC owner name is equal
+// to the qname or if the NSEC owner name is a wildcard name parent
+// of qname.
+func nsecMatchesQname(nsec *dns.NSEC, qname string) bool {
+ return nsec.Hdr.Name == qname || (strings.HasPrefix(nsec.Hdr.Name, "*.") &&
+ dns.IsSubDomain(nsec.Hdr.Name[2:], qname))
+}
+
+// nsecCoversZone returns true if the zone is within the OPEN interval
+// delimited by the nsecOwner and the nsecNext FQDNs given.
+// TODO improve inspiring from
+// https://github.com/NLnetLabs/unbound/blob/master/util/data/dname.c#L802
+func nsecCoversZone(zone, nsecOwner, nsecNext string) (ok bool) {
+ if zone == nsecOwner || zone == nsecNext {
+ return false
+ }
+
+ zoneLabels := dns.SplitDomainName(zone)
+ nsecOwnerLabels := dns.SplitDomainName(nsecOwner)
+
+ if len(zoneLabels) < len(nsecOwnerLabels) {
+ // zone is shorter than NSEC owner, so it cannot be covered
+ return false
+ }
+
+ for i := range nsecOwnerLabels {
+ zoneLabel := zoneLabels[len(zoneLabels)-1-i]
+ nsecOwnerLabel := nsecOwnerLabels[len(nsecOwnerLabels)-1-i]
+ if nsecOwnerLabel == "*" {
+ // wildcard NSEC owner containing zone
+ return true
+ } else if zoneLabel < nsecOwnerLabel {
+ return false
+ }
+ }
+
+ nsecNextLabels := dns.SplitDomainName(nsecNext)
+ if len(zoneLabels) < len(nsecNextLabels) {
+ // zone is shorter than NSEC next, so it cannot be covered
+ return false
+ }
+
+ minLabelsCount := min(len(zoneLabels), len(nsecNextLabels))
+ for i := 0; i < minLabelsCount; i++ {
+ zoneLabel := zoneLabels[len(zoneLabels)-1-i]
+ nsecNextLabel := nsecNextLabels[len(nsecNextLabels)-1-i]
+ if zoneLabel > nsecNextLabel {
+ return false
+ }
+ }
+
+ // Zone and next domain have the same labels for the first
+ // minLabelsCount labels, and zone != next, so zone is within
+ // the interval delimited by owner and next.
+ return true
+}
diff --git a/internal/dnssec/nsec3.go b/internal/dnssec/nsec3.go
new file mode 100644
index 00000000..10d25e6c
--- /dev/null
+++ b/internal/dnssec/nsec3.go
@@ -0,0 +1,427 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+ "golang.org/x/exp/maps"
+)
+
+func mustRRToNSEC3(rr dns.RR) (nsec3 *dns.NSEC3) {
+ nsec3, ok := rr.(*dns.NSEC3)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.NSEC3", rr))
+ }
+ return nsec3
+}
+
+// extractNSEC3s returns the NSEC3 RRs found in the NSEC3
+// signed RRSet from the slice of signed RRSets. It also returns
+// wildcard as true if one of the NSEC3 RRSets RRSigs is for a wildcard.
+func extractNSEC3s(rrSets []dnssecRRSet) (
+ rrs []dns.RR, wildcard bool) {
+ rrs = make([]dns.RR, 0, len(rrSets))
+ for _, rrSet := range rrSets {
+ if rrSet.qtype() != dns.TypeNSEC3 {
+ continue
+ }
+ rrs = append(rrs, rrSet.rrSet...)
+
+ if !wildcard {
+ for _, rrSig := range rrSet.rrSigs {
+ if isRRSigForWildcard(rrSig) {
+ wildcard = true
+ break
+ }
+ }
+ }
+ }
+ return rrs, wildcard
+}
+
+var (
+ ErrNSEC3RRSetDifferentHashTypes = errors.New("NSEC3 RRSet contains different hash types")
+ ErrNSEC3RRSetDifferentIterations = errors.New("NSEC3 RRSet contains different iterations")
+ ErrNSEC3RRSetDifferentSalts = errors.New("NSEC3 RRSet contains different salts")
+)
+
+func nsec3InitialChecks(nsec3RRSet []dns.RR) (sanitizedNSEC3RRSet []dns.RR, err error) {
+ sanitizedNSEC3RRSet = make([]dns.RR, 0, len(nsec3RRSet))
+
+ const usualCapacity = 1
+ hashTypes := make(map[uint8]struct{}, usualCapacity)
+ iterations := make(map[uint16]struct{}, usualCapacity)
+ salts := make(map[string]struct{}, usualCapacity)
+
+ for _, nsec3RR := range nsec3RRSet {
+ nsec3 := mustRRToNSEC3(nsec3RR)
+
+ // Only accept supported hash type
+ // https://datatracker.ietf.org/doc/html/rfc5155#section-8.1
+ if !isOneOf(nsec3.Hash, dns.SHA1) {
+ continue
+ }
+
+ // Flag field must be zero or one (opt-out).
+ // https://datatracker.ietf.org/doc/html/rfc5155#section-8.2
+ if !isOneOf(nsec3.Flags, 0, 1) {
+ continue
+ }
+
+ // Track hash algorithms, iterations and salts
+ // https://datatracker.ietf.org/doc/html/rfc5155#section-8.2
+ hashTypes[nsec3.Hash] = struct{}{}
+ iterations[nsec3.Iterations] = struct{}{}
+ salts[nsec3.Salt] = struct{}{}
+
+ sanitizedNSEC3RRSet = append(sanitizedNSEC3RRSet, nsec3RR)
+ }
+
+ // Verify all NSEC3 RRSet RRs have the same hash type, iterations and salt
+ // If not, the response may be considered as bogus, so we return an error.
+ // https://datatracker.ietf.org/doc/html/rfc5155#section-8.2
+ switch {
+ case len(hashTypes) > 1:
+ return nil, fmt.Errorf("%w: %s", ErrNSEC3RRSetDifferentHashTypes,
+ hashesToString(maps.Keys(hashTypes)))
+ case len(iterations) > 1:
+ return nil, fmt.Errorf("%w: %s", ErrNSEC3RRSetDifferentIterations,
+ integersToString(maps.Keys(iterations)))
+ case len(salts) > 1:
+ return nil, fmt.Errorf("%w: %s", ErrNSEC3RRSetDifferentSalts,
+ strings.Join(maps.Keys(salts), ", "))
+ }
+
+ return sanitizedNSEC3RRSet, nil
+}
+
+func nsec3ValidateNxDomain(qname string, nsec3RRSet []dns.RR) (err error) {
+ // Proof qname does not exist with the closest encloser proof
+ closestEncloser, err := nsec3VerifyClosestEncloserProof(qname, nsec3RRSet)
+ if err != nil {
+ return fmt.Errorf("for qname %s: "+
+ "validating closest encloser proof: %w",
+ qname, err)
+ }
+
+ // Proof the wildcard matching qname does not exist
+ wildcardName := "*." + closestEncloser
+ wildcardCoveringNSEC3 := nsec3FindCovering(wildcardName, nsec3RRSet)
+ if wildcardCoveringNSEC3 == nil {
+ return fmt.Errorf("for qname %s: %w: "+
+ "NSEC3 matching wildcard %s not found",
+ qname, ErrBogus, wildcardName)
+ }
+
+ return nil
+}
+
+// nsec3ValidateNoData validates a no data response for a given QTYPE.
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.5
+// and https://datatracker.ietf.org/doc/html/rfc5155#section-8.6
+func nsec3ValidateNoData(qname string, qType uint16,
+ nsec3RRSet []dns.RR) (err error) {
+ if qType == dns.TypeDS {
+ return nsec3ValidateNoDataDS(qname, nsec3RRSet)
+ }
+
+ err = nsec3RRSetHasMatchingWithoutTypes(nsec3RRSet,
+ qname, qType, dns.TypeCNAME)
+ if err != nil {
+ return fmt.Errorf("for qname %s: %w", qname, err)
+ }
+ return nil
+}
+
+// nsec3ValidateNoDataDS is used internally in nsec3VerifyNoData.
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.6
+func nsec3ValidateNoDataDS(qname string, nsec3RRSet []dns.RR) (err error) {
+ qnameMatchingNSEC3 := nsec3FindMatching(qname, nsec3RRSet)
+ if qnameMatchingNSEC3 != nil {
+ err = verifyNoDataNsecxTypesDS("NSEC3", qnameMatchingNSEC3.TypeBitMap)
+ if err != nil {
+ return fmt.Errorf("for qname %s: %w", qname, err)
+ }
+ return nil
+ }
+
+ // No matching NSEC3 found, first verify the closest encloser proof
+ // for qname exists.
+ closestEncloser, err := nsec3VerifyClosestEncloserProof(qname, nsec3RRSet)
+ if err != nil {
+ return fmt.Errorf("for qname %s: "+
+ "validating closest encloser proof: %w",
+ qname, err)
+ }
+ nextCloser := getNextCloser(qname, closestEncloser)
+
+ // Verify the NSEC3 covering the next closer name has the Opt-Out bit set.
+ nextCloserCoveringNSEC3 := nsec3FindCovering(nextCloser, nsec3RRSet)
+ if nextCloserCoveringNSEC3 == nil {
+ return fmt.Errorf("for qname %s: %w: "+
+ "no NSEC3 covers next closer %s",
+ qname, ErrBogus, nextCloser)
+ }
+
+ optOutBitSet := nextCloserCoveringNSEC3.Flags == 1
+ if !optOutBitSet {
+ return fmt.Errorf("for qname %s: %w: "+
+ "NSEC3 covering next closer %s Opt-Out bit %d is not set",
+ qname, ErrBogus, nextCloser, nextCloserCoveringNSEC3.Flags)
+ }
+
+ return nil
+}
+
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.7
+func nsec3ValidateNoDataWildcard(qname string, qType uint16,
+ nsec3RRSet []dns.RR) (err error) {
+ // Proof qname does not exist with the closest encloser proof
+ closestEncloser, err := nsec3VerifyClosestEncloserProof(qname, nsec3RRSet)
+ if err != nil {
+ return fmt.Errorf("for qname %s: "+
+ "validating closest encloser proof: %w",
+ qname, err)
+ }
+
+ // Proof the wildcard matching qname exists
+ wildcardName := "*." + closestEncloser
+ err = nsec3RRSetHasMatchingWithoutTypes(nsec3RRSet,
+ wildcardName, qType, dns.TypeCNAME)
+ if err != nil {
+ return fmt.Errorf("for qname %s: %w", qname, err)
+ }
+
+ return nil
+}
+
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.8
+func nsec3ValidateWildcard(qname string, nsec3RRSet []dns.RR) (err error) {
+ candidateClosestEncloser, err := nsec3VerifyClosestEncloserProof(qname, nsec3RRSet)
+ if err != nil {
+ return fmt.Errorf("for qname %s: "+
+ "validating closest encloser proof: %w",
+ qname, err)
+ }
+ // This closest encloser is the immediate ancestor to the
+ // generating wildcard.
+
+ // Validators MUST verify that there is an NSEC3 RR that covers the
+ // "next closer" name to QNAME present in the response. This proves
+ // that QNAME itself did not exist and that the correct wildcard was
+ // used to generate the response.
+ nextCloser := getNextCloser(qname, candidateClosestEncloser)
+ nextCloserCoveringNSEC3 := nsec3FindCovering(nextCloser, nsec3RRSet)
+ if nextCloserCoveringNSEC3 != nil {
+ return nil
+ }
+
+ return fmt.Errorf("for qname %s: %w: "+
+ "no NSEC3 covers next closer %s",
+ qname, ErrBogus, nextCloser)
+}
+
+// The delegationName argument is the owner name of the NS RRSet in the
+// authority section of the response.
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.9
+//
+//nolint:unused
+func nsec3ValidateReferralsToUnsignedSubzones(qname, delegationName string,
+ nsec3RRSet []dns.RR) (err error) {
+ matchingNSEC3 := nsec3FindMatching(qname, nsec3RRSet)
+ if matchingNSEC3 != nil {
+ var hasNS bool
+ for _, nsec3Type := range matchingNSEC3.TypeBitMap {
+ switch nsec3Type {
+ case dns.TypeNS:
+ // This implies the absence of a DNAME type
+ hasNS = true
+ case dns.TypeDS:
+ return fmt.Errorf("for qname %s and delegation name %s: %w: "+
+ "NSEC3 matching the delegation name contains DS type",
+ qname, delegationName, ErrBogus)
+ case dns.TypeSOA:
+ return fmt.Errorf("for qname %s and delegation name %s: %w: "+
+ "NSEC3 matching the delegation name contains SOA type",
+ qname, delegationName, ErrBogus)
+ }
+ }
+
+ if !hasNS {
+ return fmt.Errorf("for qname %s and delegation name %s: %w: "+
+ "NSEC3 matching the delegation name does not contain NS type",
+ qname, delegationName, ErrBogus)
+ }
+
+ return nil
+ }
+
+ // No NSEC3 matching the delegation name found
+ closestEncloser, err := nsec3VerifyClosestEncloserProof(
+ delegationName, nsec3RRSet)
+ if err != nil {
+ return fmt.Errorf("for qname %s and delegation name %s: "+
+ "validating closest encloser proof: %w",
+ qname, delegationName, err)
+ }
+
+ nextCloser := getNextCloser(delegationName, closestEncloser)
+ nextCloserCoveringNSEC3 := nsec3FindCovering(nextCloser, nsec3RRSet)
+ if nextCloserCoveringNSEC3 == nil {
+ return fmt.Errorf("for qname %s and delegation name %s: %w: "+
+ "no NSEC3 covers next closer %s",
+ qname, delegationName, ErrBogus, nextCloser)
+ }
+
+ optOutBitSet := nextCloserCoveringNSEC3.Flags == 1
+ if optOutBitSet {
+ return nil
+ }
+ return fmt.Errorf("for qname %s and delegation name %s: %w: "+
+ "NSEC3 covering next closer %s Opt-Out bit %d is not set",
+ qname, delegationName, ErrBogus, nextCloser, nextCloserCoveringNSEC3.Flags)
+}
+
+// nsec3VerifyClosestEncloserProof validates a closest encloser proof,
+// and returns the closest encloser name if the proof is valid.
+// If the proof is not valid, an error is returned.
+// For such proof to be valid, the longest name X must be found such that:
+// - X is an ancestor of qname that is matched by an NSEC3 RR
+// - the name one label longer than X (ancestor of qname or equal to qname)
+// is covered by an NSEC3 RR.
+//
+// See https://datatracker.ietf.org/doc/html/rfc5155#section-8.3
+// The implementation is based on the pseudo code from the RFC.
+//
+//nolint:cyclop
+func nsec3VerifyClosestEncloserProof(qname string, nsec3RRSet []dns.RR) (
+ closestEncloser string, err error) {
+ sname := qname
+
+ for {
+ var matchingNSEC3 *dns.NSEC3
+ snameCovered := false
+ for _, nsec3RR := range nsec3RRSet {
+ nsec3 := mustRRToNSEC3(nsec3RR)
+
+ if nsec3.Cover(sname) {
+ snameCovered = true
+ }
+
+ if nsec3.Match(sname) {
+ matchingNSEC3 = nsec3
+ }
+ }
+
+ if matchingNSEC3 != nil {
+ if !snameCovered {
+ return "", fmt.Errorf("%w: sname %s matched but not covered",
+ ErrBogus, sname)
+ }
+ closestEncloser = sname
+
+ // The DNAME type bit must not be set and the NS type bit may
+ // only be set if the SOA type bit is set.
+ // If this is not the case, it would be an indication that an attacker
+ // is using them to falsely deny the existence of RRs for which the
+ // server is not authoritative.
+ var hasNS, hasSOA bool
+ for _, nsec3Type := range matchingNSEC3.TypeBitMap {
+ switch nsec3Type {
+ case dns.TypeDNAME:
+ return "", fmt.Errorf("%w: NSEC3 of closest encloser %s "+
+ "contains the DNAME type", ErrBogus, sname)
+ case dns.TypeNS:
+ hasNS = true
+ case dns.TypeSOA:
+ hasSOA = true
+ }
+ }
+ if hasNS && !hasSOA {
+ return "", fmt.Errorf("%w: NSEC3 of closest encloser %s "+
+ "contains the NS type but not the SOA type", ErrBogus, sname)
+ }
+
+ return closestEncloser, nil
+ }
+
+ const offset = 0
+ i, end := dns.NextLabel(sname, offset)
+ if end {
+ return "", fmt.Errorf("%w: sname reached the last label already", ErrBogus)
+ }
+ sname = sname[i:]
+ }
+}
+
+// getNextCloser returns the "next closer" name of qname given a closest
+// encloser name.
+// For example with qname="a.b.example.com." and closestEncloser=".com.",
+// then nextCloser="example.com.".
+func getNextCloser(qname, closestEncloser string) (nextCloser string) {
+ closestEncloserLabelsCount := dns.CountLabel(closestEncloser)
+ qnameLabelsCount := dns.CountLabel(qname)
+
+ // Double check the qname is two labels longer than the closest encloser.
+ // TODO eventual remove check
+ if qnameLabelsCount < closestEncloserLabelsCount+1 {
+ panic(fmt.Sprintf("qname %s is not at least one label longer than closest encloser %s",
+ qname, closestEncloser))
+ }
+
+ nextCloserStartIndex, startOvershoot := dns.PrevLabel(qname, closestEncloserLabelsCount+1)
+ if startOvershoot {
+ panic("start overshoot should not happen")
+ }
+ nextCloser = qname[nextCloserStartIndex:]
+
+ return nextCloser
+}
+
+// nsec3RRSetHasMatchingWithoutTypes returns an error if:
+// - there is no NSEC3 matching matchName
+// - the NSEC3 matching matchName contains one of the notTypes.
+func nsec3RRSetHasMatchingWithoutTypes(nsec3RRSet []dns.RR,
+ matchName string, notTypes ...uint16) (err error) {
+ matchingNSEC3 := nsec3FindMatching(matchName, nsec3RRSet)
+ if matchingNSEC3 == nil {
+ return fmt.Errorf("%w: no NSEC3 matching %s",
+ ErrBogus, matchName)
+ }
+
+ for _, nsec3Type := range matchingNSEC3.TypeBitMap {
+ for _, notType := range notTypes {
+ if nsec3Type != notType {
+ continue
+ }
+ return fmt.Errorf("%w: NSEC3 matching %s contains type %s",
+ ErrBogus, matchName, dns.TypeToString[notType])
+ }
+ }
+
+ return nil
+}
+
+func nsec3FindMatching(qname string, nsec3RRSet []dns.RR) (
+ matchingNSEC3 *dns.NSEC3) {
+ for _, nsec3RR := range nsec3RRSet {
+ nsec3 := mustRRToNSEC3(nsec3RR)
+ if nsec3.Match(qname) {
+ return nsec3
+ }
+ }
+ return nil
+}
+
+func nsec3FindCovering(qname string, nsec3RRSet []dns.RR) (
+ coveringNSEC3 *dns.NSEC3) {
+ for _, nsec3RR := range nsec3RRSet {
+ nsec3 := mustRRToNSEC3(nsec3RR)
+ if nsec3.Cover(qname) {
+ return nsec3
+ }
+ }
+ return nil
+}
diff --git a/internal/dnssec/nsec3_test.go b/internal/dnssec/nsec3_test.go
new file mode 100644
index 00000000..5dadaf08
--- /dev/null
+++ b/internal/dnssec/nsec3_test.go
@@ -0,0 +1,39 @@
+package dnssec
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_getNextCloser(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ qname string
+ closestEncloser string
+ nextCloser string
+ }{
+ "case1": {
+ qname: "a.b.example.com.",
+ closestEncloser: "example.com.",
+ nextCloser: "b.example.com.",
+ },
+ "q_name_is_next_closer": {
+ qname: "a.example.com.",
+ closestEncloser: "example.com.",
+ nextCloser: "a.example.com.",
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ nextCloser := getNextCloser(testCase.qname, testCase.closestEncloser)
+
+ assert.Equal(t, testCase.nextCloser, nextCloser)
+ })
+ }
+}
diff --git a/internal/dnssec/nsec_test.go b/internal/dnssec/nsec_test.go
new file mode 100644
index 00000000..741f0ec9
--- /dev/null
+++ b/internal/dnssec/nsec_test.go
@@ -0,0 +1,81 @@
+package dnssec
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_nsecCover(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ zone string
+ nsecOwner string
+ nsecNext string
+ ok bool
+ }{
+ "zone_shorter_than_owner": {
+ zone: "example.com.",
+ nsecOwner: "a.example.com.",
+ },
+ "zone_before_owner": {
+ zone: "a.example.com.",
+ nsecOwner: "b.example.com.",
+ },
+ "zone_not_subdomain_of_owner": {
+ zone: "a.a.example.com.",
+ nsecOwner: "b.example.com.",
+ },
+ "malformed_longer_next": {
+ zone: "b.example.com.",
+ nsecOwner: "a.example.com.",
+ nsecNext: "c.c.example.com.",
+ },
+ "zone_equal_to_next": {
+ zone: "b.example.com.",
+ nsecOwner: "a.example.com.",
+ nsecNext: "b.example.com.",
+ },
+ "zone_after_next": {
+ zone: "c.example.com.",
+ nsecOwner: "a.example.com.",
+ nsecNext: "b.example.com.",
+ },
+ "zone_not_subdomain_of_next": {
+ zone: "b.b.example.com.",
+ nsecOwner: "a.example.com.",
+ nsecNext: "c.example.com.",
+ ok: true,
+ },
+ "wildcard_a": {
+ zone: "a.example.com.",
+ nsecOwner: "*.example.com.",
+ nsecNext: "example.com.",
+ ok: true,
+ },
+ "wildcard_a.a": {
+ zone: "a.a.example.com.",
+ nsecOwner: "*.example.com.",
+ nsecNext: "example.com.",
+ ok: true,
+ },
+ "wildcard_#": {
+ zone: "#.example.com.",
+ nsecOwner: "*.example.com.",
+ nsecNext: "example.com.",
+ ok: true,
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ ok := nsecCoversZone(testCase.zone, testCase.nsecOwner, testCase.nsecNext)
+
+ assert.Equal(t, testCase.ok, ok)
+ })
+ }
+}
diff --git a/internal/dnssec/nxdomain.go b/internal/dnssec/nxdomain.go
new file mode 100644
index 00000000..cef6b65b
--- /dev/null
+++ b/internal/dnssec/nxdomain.go
@@ -0,0 +1,40 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+var (
+ ErrRRSigWildcardUnexpected = errors.New("RRSIG for a wildcard is unexpected")
+)
+
+func validateNxDomain(qname string, authoritySection []dnssecRRSet,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ err = verifyRRSetsRRSig(authoritySection, keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying RRSIGs: %w", err)
+ }
+
+ nsec3RRs, wildcard := extractNSEC3s(authoritySection)
+ if wildcard {
+ return fmt.Errorf("for NXDOMAIN response for %s: NSEC3: %w",
+ qname, ErrRRSigWildcardUnexpected)
+ } else if len(nsec3RRs) > 0 {
+ nsec3RRs, err = nsec3InitialChecks(nsec3RRs)
+ if err != nil {
+ return fmt.Errorf("initial NSEC3 checks: %w", err)
+ }
+ return nsec3ValidateNxDomain(qname, nsec3RRs)
+ }
+
+ nsecRRs := extractNSECs(authoritySection)
+ if len(nsecRRs) > 0 {
+ return nsecValidateNxDomain(qname, nsecRRs)
+ }
+
+ return fmt.Errorf("for %s: %w: no NSEC or NSEC3 record found",
+ qname, ErrBogus)
+}
diff --git a/internal/dnssec/query.go b/internal/dnssec/query.go
new file mode 100644
index 00000000..d0178bb0
--- /dev/null
+++ b/internal/dnssec/query.go
@@ -0,0 +1,195 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/miekg/dns"
+ "github.com/qdm12/dns/v2/internal/stateful"
+)
+
+var (
+ ErrRcodeBad = errors.New("bad response rcode")
+)
+
+func queryRRSets(handler dns.Handler, zone string,
+ qClass, qType uint16) (response dnssecResponse, err error) {
+ request := newEDNSRequest(zone, qClass, qType)
+
+ statefulWriter := stateful.NewWriter()
+ handler.ServeDNS(statefulWriter, request)
+ dnsResponse := statefulWriter.Response
+ response.rcode = dnsResponse.Rcode
+
+ switch {
+ case dnsResponse.Rcode == dns.RcodeSuccess && len(dnsResponse.Answer) > 0:
+ // Success and we have at least one answer RR.
+ response.answerRRSets, err = groupRRs(dnsResponse.Answer)
+ if err != nil {
+ return dnssecResponse{}, fmt.Errorf(
+ "grouping answer RRSets for %s: %w",
+ nameClassTypeToString(zone, qClass, qType), err)
+ }
+
+ if !response.isSigned() {
+ // We have all unsigned answers
+ return response, nil
+ }
+
+ // Every RRSet has at least one RRSIG associated with it.
+ // The caller should then verify the RRSIGs and MAY need
+ // NSEC or NSEC3 RRSets from the authority section to verify
+ // it does not match a wildcard.
+ response.authorityRRSets, err = groupRRs(dnsResponse.Ns)
+ if err != nil {
+ return dnssecResponse{}, fmt.Errorf(
+ "grouping authority RRSets for %s: %w",
+ nameClassTypeToString(zone, qClass, qType), err)
+ }
+
+ return response, nil
+ case dnsResponse.Rcode == dns.RcodeSuccess && len(dnsResponse.Answer) == 0,
+ dnsResponse.Rcode == dns.RcodeNameError:
+ // NXDOMAIN or NODATA response, we need to verify the negative
+ // response with the query authority section NSEC/NSEC3 RRSet
+ // or verify the zone is insecure.
+ // If the zone is insecure, the caller verifies the zone is
+ // insecure using the NSEC/NSEC3 records of the authority
+ // section of the DS query for that zone, or any first of
+ // its parent zone with an NSEC/NSEC3 record for that zone,
+ // walking towards the root zone.
+ // There is no difference in handling if we received a NODATA
+ // or NXDOMAIN response.
+
+ if len(dnsResponse.Ns) == 0 {
+ // No authority RR so there cannot be any NSEC/NSEC3 RRSet,
+ // the zone is thus insecure.
+ return response, nil
+ }
+
+ response.authorityRRSets, err = groupRRs(dnsResponse.Ns)
+ if err != nil {
+ return dnssecResponse{}, fmt.Errorf(
+ "grouping authority RRSets for %s: %w",
+ nameClassTypeToString(zone, qClass, qType), err)
+ }
+
+ // TODO make sure we ignore nsec without rrsig
+ return response, nil
+ default: // other error
+ // If the response Rcode is dns.RcodeServerFailure,
+ // this may mean DNSSEC validation failed on the upstream server.
+ // https://www.ietf.org/rfc/rfc4033.txt
+ // This specification only defines how security-aware name servers can
+ // signal non-validating stub resolvers that data was found to be bogus
+ // (using RCODE=2, "Server Failure"; see [RFC4035]).
+ return dnssecResponse{}, fmt.Errorf(
+ "for %s: %w: %s",
+ nameClassTypeToString(zone, qClass, qType),
+ ErrRcodeBad, dns.RcodeToString[dnsResponse.Rcode])
+ }
+}
+
+var (
+ ErrRRSetSignedAndUnsigned = errors.New("mix of signed and unsigned RRSets")
+ ErrRRSigForNoRRSet = errors.New("RRSIG for no RRSet")
+)
+
+// groupRRs groups RRs by type AND owner AND class, returning a slice
+// of 'DNSSEC RRSets' where each contains at least one RR,
+// and zero or one RRSIG signature.
+// Regarding the RRSig validity requirements listed in
+// https://datatracker.ietf.org/doc/html/rfc4035#section-5.3.1
+//
+// The following requirements are fullfiled by design:
+// - The RRSIG RR and the RRset MUST have the same owner name and the
+// same class
+// - The RRSIG RR's Type Covered field MUST equal the RRset's type.
+//
+// And the function returns an error for the following unmet requirements:
+// - The RRSIG RR's Signer's Name field MUST be the name of the zone
+// that contains the RRset.
+// - The number of labels in the RRset owner name MUST be greater than
+// or equal to the value in the RRSIG RR's Labels field.
+//
+// The following requirements are enforced at a later stage:
+// - The validator's notion of the current time MUST be less than or
+// equal to the time listed in the RRSIG RR's Expiration field.
+func groupRRs(rrs []dns.RR) (dnssecRRSets []dnssecRRSet, err error) {
+ // For well formed DNSSEC DNS answers, there should be at most
+ // N/2 signed RRSets (grouped by qname-qtype-qclass) where N is
+ // the number of total answers.
+ maxRRSets := len(rrs) // all unsigned RRs
+ dnssecRRSets = make([]dnssecRRSet, 0, maxRRSets)
+ type typeZoneKey struct {
+ rrType uint16
+ owner string
+ class uint16
+ }
+ typeZoneToIndex := make(map[typeZoneKey]int, maxRRSets)
+
+ // Used to check we have all signed RRSets or all
+ // unsigned RRSets.
+ signedRRSetsCount := 0
+ for _, rr := range rrs {
+ header := rr.Header()
+ typeZoneKey := typeZoneKey{
+ owner: header.Name,
+ class: header.Class,
+ }
+
+ rrType := header.Rrtype
+ if rrType == dns.TypeRRSIG {
+ rrsig := mustRRToRRSig(rr)
+ err = rrsigInitialChecks(rrsig)
+ if err != nil {
+ return nil, err
+ }
+
+ typeZoneKey.rrType = rrsig.TypeCovered
+ i, ok := typeZoneToIndex[typeZoneKey]
+ if !ok {
+ dnssecRRSets = append(dnssecRRSets, dnssecRRSet{})
+ i = len(dnssecRRSets) - 1
+ typeZoneToIndex[typeZoneKey] = i
+ }
+
+ if len(dnssecRRSets[i].rrSigs) == 0 {
+ signedRRSetsCount++
+ }
+ dnssecRRSets[i].rrSigs = append(dnssecRRSets[i].rrSigs, rrsig)
+ continue
+ }
+
+ typeZoneKey.rrType = rrType
+ i, ok := typeZoneToIndex[typeZoneKey]
+ if !ok {
+ dnssecRRSets = append(dnssecRRSets, dnssecRRSet{})
+ i = len(dnssecRRSets) - 1
+ typeZoneToIndex[typeZoneKey] = i
+ }
+ dnssecRRSets[i].rrSet = append(dnssecRRSets[i].rrSet, rr)
+ }
+
+ // Verify all RRSets are either signed or unsigned.
+ switch signedRRSetsCount {
+ case 0:
+ case len(dnssecRRSets):
+ default:
+ unsignedRRSetsCount := len(dnssecRRSets) - signedRRSetsCount
+ return nil, fmt.Errorf("%w: %d signed and %d unsigned RRSets",
+ ErrRRSetSignedAndUnsigned, signedRRSetsCount, unsignedRRSetsCount)
+ }
+
+ // Verify built DNSSEC RRSets are well formed.
+ for _, dnssecRRSet := range dnssecRRSets {
+ if len(dnssecRRSet.rrSigs) > 0 && len(dnssecRRSet.rrSet) == 0 {
+ return nil, fmt.Errorf("for RRSet %s %s: %w",
+ dnssecRRSet.rrSigs[0].Hdr.Name,
+ dns.TypeToString[dnssecRRSet.rrSigs[0].TypeCovered],
+ ErrRRSigForNoRRSet)
+ }
+ }
+
+ return dnssecRRSets, nil
+}
diff --git a/internal/dnssec/query_test.go b/internal/dnssec/query_test.go
new file mode 100644
index 00000000..6358f03c
--- /dev/null
+++ b/internal/dnssec/query_test.go
@@ -0,0 +1,150 @@
+package dnssec
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_groupRRs(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ rrs []dns.RR
+ dnssecRRSets []dnssecRRSet
+ errWrapped error
+ errMessage string
+ }{
+ "no_rrs": {
+ dnssecRRSets: []dnssecRRSet{},
+ },
+ "bad_single_rrsig_answer": {
+ rrs: []dns.RR{
+ newEmptyRRSig(dns.TypeA),
+ },
+ errWrapped: ErrRRSigForNoRRSet,
+ errMessage: "for RRSet example.com. A: RRSIG for no RRSet",
+ },
+ "bad_rrsig_for_no_rrset": {
+ rrs: []dns.RR{
+ newEmptyAAAA(),
+ newEmptyRRSig(dns.TypeAAAA),
+ newEmptyRRSig(dns.TypeA), // bad one
+ },
+ errWrapped: ErrRRSigForNoRRSet,
+ errMessage: "for RRSet example.com. A: RRSIG for no RRSet",
+ },
+ "multiple_rrsig_for_same_type": {
+ rrs: []dns.RR{
+ newEmptyRRSig(dns.TypeA),
+ newEmptyA(),
+ newEmptyRRSig(dns.TypeA),
+ },
+ dnssecRRSets: []dnssecRRSet{
+ {
+ rrSet: []dns.RR{
+ newEmptyA(),
+ },
+ rrSigs: []*dns.RRSIG{
+ newEmptyRRSig(dns.TypeA),
+ newEmptyRRSig(dns.TypeA),
+ },
+ },
+ },
+ },
+ "bad_signed_and_not_signed_rrsets": {
+ rrs: []dns.RR{
+ newEmptyRRSig(dns.TypeA),
+ newEmptyAAAA(),
+ newEmptyA(),
+ },
+ errWrapped: ErrRRSetSignedAndUnsigned,
+ errMessage: "mix of signed and unsigned RRSets: 1 signed and 1 unsigned RRSets",
+ },
+ "signed_rrsets": {
+ rrs: []dns.RR{
+ newEmptyRRSig(dns.TypeA),
+ newEmptyA(),
+ newEmptyAAAA(),
+ newEmptyRRSig(dns.TypeAAAA),
+ },
+ dnssecRRSets: []dnssecRRSet{
+ {
+ rrSigs: []*dns.RRSIG{newEmptyRRSig(dns.TypeA)},
+ rrSet: []dns.RR{
+ newEmptyA(),
+ },
+ },
+ {
+ rrSigs: []*dns.RRSIG{newEmptyRRSig(dns.TypeAAAA)},
+ rrSet: []dns.RR{
+ newEmptyAAAA(),
+ },
+ },
+ },
+ },
+ "not_signed_rrsets": {
+ rrs: []dns.RR{
+ newEmptyA(),
+ newEmptyAAAA(),
+ },
+ dnssecRRSets: []dnssecRRSet{
+ {
+ rrSet: []dns.RR{
+ newEmptyA(),
+ },
+ },
+ {
+ rrSet: []dns.RR{
+ newEmptyAAAA(),
+ },
+ },
+ },
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ dnssecRRSets, err := groupRRs(testCase.rrs)
+
+ assert.Equal(t, testCase.dnssecRRSets, dnssecRRSets)
+ assert.ErrorIs(t, err, testCase.errWrapped)
+ if testCase.errWrapped != nil {
+ assert.EqualError(t, err, testCase.errMessage)
+ }
+ })
+ }
+}
+
+func newEmptyRRSig(typeCovered uint16) *dns.RRSIG {
+ return &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ Rrtype: dns.TypeRRSIG,
+ },
+ TypeCovered: typeCovered,
+ SignerName: "example.com.",
+ }
+}
+
+func newEmptyA() *dns.A {
+ return &dns.A{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ Rrtype: dns.TypeA,
+ },
+ }
+}
+
+func newEmptyAAAA() *dns.AAAA {
+ return &dns.AAAA{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ Rrtype: dns.TypeAAAA,
+ },
+ }
+}
diff --git a/internal/dnssec/readme.md b/internal/dnssec/readme.md
new file mode 100644
index 00000000..02584f97
--- /dev/null
+++ b/internal/dnssec/readme.md
@@ -0,0 +1,39 @@
+# DNSSEC validator
+
+This package implements a DNSSEC validator middleware for a DNS forwarding server.
+It performs all queries through the next DNS handler middleware.... TODO
+
+## Comments
+
+Comments are aimed to be minimal and clear code is preferred over comments.
+However, there are a few references to specific sections of IETF RFCs especially when it comes to function comments.
+
+## Terminology used
+
+Teminology used in the code aims at being as close as possible to IETF RFCs.
+
+When this is unclear, there are a few specific rules used, for example:
+
+- Using `validate` or `verify`, see the `validate vs. verify` in [RFC4949](https://datatracker.ietf.org/doc/html/rfc4949)
+
+## Documentation used
+
+### IETF RFCs
+
+- [RFC4033](https://datatracker.ietf.org/doc/html/rfc4033)
+- [RFC5155 on NSEC3](https://datatracker.ietf.org/doc/html/rfc5155)
+- [RFC8624 on DNSKEY algorithms](https://datatracker.ietf.org/doc/html/rfc8624#section-3.1)
+
+### Blog posts
+
+
+
+
+### Videos
+
+- [DNSSEC Series #5 Record types, keys, signatures and NSEC](https://www.youtube.com/watch?v=FGs9kbdgMXE&t=2825s)
+
+## Tools used to help debugging
+
+- [DNSViz](https://dnsviz.net/)
+- [DNSSEC Analyzer](https://dnssec-analyzer.verisignlabs.com)
diff --git a/internal/dnssec/response.go b/internal/dnssec/response.go
new file mode 100644
index 00000000..b6061204
--- /dev/null
+++ b/internal/dnssec/response.go
@@ -0,0 +1,61 @@
+package dnssec
+
+import (
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+type dnssecResponse struct {
+ answerRRSets []dnssecRRSet
+ authorityRRSets []dnssecRRSet
+ rcode int
+}
+
+func (d dnssecResponse) isNXDomain() bool {
+ return d.rcode == dns.RcodeNameError
+}
+
+func (d dnssecResponse) isNoData() bool {
+ return d.rcode == dns.RcodeSuccess && len(d.answerRRSets) == 0
+}
+
+func (d dnssecResponse) isSigned() bool {
+ // Note a slice of DNSSEC RRSets is either all signed or all unsigned.
+ switch {
+ case len(d.answerRRSets) > 0 && len(d.answerRRSets[0].rrSigs) == 0,
+ len(d.authorityRRSets) > 0 && len(d.authorityRRSets[0].rrSigs) == 0,
+ len(d.answerRRSets) == 0 && len(d.authorityRRSets) == 0:
+ return false
+ default:
+ return true
+ }
+}
+
+func (d dnssecResponse) onlyAnswerRRSet() (rrSet []dns.RR) {
+ if len(d.answerRRSets) != 1 {
+ panic(fmt.Sprintf("DNSSEC response has %d answer RRSets instead of 1",
+ len(d.answerRRSets)))
+ }
+ return d.answerRRSets[0].rrSet
+}
+
+func (d dnssecResponse) onlyAnswerRRSigs() (rrSigs []*dns.RRSIG) {
+ if len(d.answerRRSets) != 1 {
+ panic(fmt.Sprintf("DNSSEC response has %d answer RRSets instead of 1",
+ len(d.answerRRSets)))
+ }
+ return d.answerRRSets[0].rrSigs
+}
+
+func (d dnssecResponse) toDNSMsg(request *dns.Msg) (response *dns.Msg) {
+ response = new(dns.Msg)
+ response.SetRcode(request, d.rcode)
+ var ignoreTypes []uint16
+ if !isRequestAskingForDNSSEC(request) {
+ ignoreTypes = []uint16{dns.TypeNSEC, dns.TypeNSEC3, dns.TypeRRSIG}
+ }
+ response.Answer = dnssecRRSetsToRRs(d.answerRRSets, ignoreTypes...)
+ response.Ns = dnssecRRSetsToRRs(d.authorityRRSets, ignoreTypes...)
+ return response
+}
diff --git a/internal/dnssec/rrset.go b/internal/dnssec/rrset.go
new file mode 100644
index 00000000..e91e0259
--- /dev/null
+++ b/internal/dnssec/rrset.go
@@ -0,0 +1,97 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/miekg/dns"
+)
+
+// dnssecRRSet is a possibly signed RRSet for a certain
+// owner, type and class, containing at least one or more
+// RRs and zero or more RRSigs.
+// If the RRSet is unsigned, the rrSigs field is a slice
+// of length 0.
+type dnssecRRSet struct {
+ // rrSigs is the slice of RRSIGs for the RRSet.
+ // There can be more than one RRSIG, for example:
+ // dig +dnssec -t A xyzzy14.sdsmt.edu. @1.1.1.1
+ // returns 2 RRSIGs for the SOA authority section RRSet.
+ rrSigs []*dns.RRSIG
+ // rrSet cannot be empty.
+ rrSet []dns.RR
+}
+
+func (d dnssecRRSet) qtype() uint16 {
+ return d.rrSet[0].Header().Rrtype
+}
+
+func dnssecRRSetsToRRs(rrSets []dnssecRRSet, ignoreTypes ...uint16) (rrs []dns.RR) {
+ if len(rrSets) == 0 {
+ return nil
+ }
+
+ ignoreTypesMap := make(map[uint16]struct{}, len(ignoreTypes))
+ for _, ignoreType := range ignoreTypes {
+ ignoreTypesMap[ignoreType] = struct{}{}
+ }
+
+ minRRSetSize := len(rrSets) // 1 RR per owner, type and class
+ rrs = make([]dns.RR, 0, minRRSetSize)
+ for _, rrSet := range rrSets {
+ for _, rr := range rrSet.rrSet {
+ rrType := rr.Header().Rrtype
+ _, ignore := ignoreTypesMap[rrType]
+ if ignore {
+ continue
+ }
+ rrs = append(rrs, rr)
+ }
+
+ _, rrSigIgnored := ignoreTypesMap[dns.TypeRRSIG]
+ if rrSigIgnored {
+ continue
+ }
+
+ for _, rrSig := range rrSet.rrSigs {
+ _, ignored := ignoreTypesMap[rrSig.TypeCovered]
+ if ignored {
+ continue
+ }
+ rrs = append(rrs, rrSig)
+ }
+ }
+
+ if len(rrs) == 0 {
+ // All RRSets types were ignored
+ return nil
+ }
+
+ return rrs
+}
+
+var (
+ ErrRRSetsMissing = errors.New("no RRSet")
+ ErrRRSetsMultiple = errors.New("multiple RRSets")
+ ErrRRSetTypeUnexpected = errors.New("RRSet type unexpected")
+)
+
+func dnssecRRSetsIsSingleOfType(rrSets []dnssecRRSet, qType uint16) (err error) {
+ switch {
+ case len(rrSets) == 0:
+ return fmt.Errorf("%w", ErrRRSetsMissing)
+ case len(rrSets) == 1:
+ default:
+ return fmt.Errorf("%w: received %d RRSets instead of 1",
+ ErrRRSetsMultiple, len(rrSets))
+ }
+
+ rrSetType := rrSets[0].qtype()
+ if rrSetType != qType {
+ return fmt.Errorf("%w: received %s RRSet instead of %s",
+ ErrRRSetTypeUnexpected, dns.TypeToString[rrSetType],
+ dns.TypeToString[qType])
+ }
+
+ return nil
+}
diff --git a/internal/dnssec/rrsig.go b/internal/dnssec/rrsig.go
new file mode 100644
index 00000000..12c8d82b
--- /dev/null
+++ b/internal/dnssec/rrsig.go
@@ -0,0 +1,196 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+ "sort"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+func mustRRToRRSig(rr dns.RR) (rrSig *dns.RRSIG) {
+ rrSig, ok := rr.(*dns.RRSIG)
+ if !ok {
+ panic(fmt.Sprintf("RR is of type %T and not of type *dns.RRSIG", rr))
+ }
+ return rrSig
+}
+
+func rrSigToOwnerTypeCovered(rrSig *dns.RRSIG) (ownerTypeCovered string) {
+ return fmt.Sprintf("RRSIG for owner %s and type %s",
+ rrSig.Header().Name, dns.TypeToString[rrSig.TypeCovered])
+}
+
+// isRRSigForWildcard returns true if the RRSIG is for a wildcard.
+// This is detected by checking if the number of labels in the RRSIG
+// owner name is less than the number of labels in the RRSig owner name.
+// See https://datatracker.ietf.org/doc/html/rfc7129#section-5.3
+func isRRSigForWildcard(rrSig *dns.RRSIG) bool {
+ if rrSig == nil {
+ return false
+ }
+ ownerLabelsCount := uint8(dns.CountLabel(rrSig.Hdr.Name))
+ return rrSig.Labels < ownerLabelsCount
+}
+
+var (
+ ErrRRSigLabels = errors.New("RRSIG labels greater than owner labels")
+)
+
+// See https://datatracker.ietf.org/doc/html/rfc4035#section-5.3.1
+func rrsigInitialChecks(rrsig *dns.RRSIG) (err error) {
+ rrSetOwner := rrsig.Hdr.Name
+
+ err = rrSigCheckSignerName(rrsig)
+ if err != nil {
+ return err
+ }
+
+ if int(rrsig.Labels) > dns.CountLabel(rrSetOwner) {
+ // The number of labels in the RRset owner name MUST be greater than
+ // or equal to the value in the RRSIG RR's Labels field.
+ return fmt.Errorf("for %s: %w: RRSig labels field is %d and owner is %d labels",
+ rrSigToOwnerTypeCovered(rrsig), ErrRRSigLabels,
+ rrsig.Labels, dns.CountLabel(rrSetOwner))
+ }
+
+ return nil
+}
+
+func verifyRRSetsRRSig(answerRRSets []dnssecRRSet, keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ for _, signedRRSet := range answerRRSets {
+ err = verifyRRSetRRSigs(signedRRSet.rrSet,
+ signedRRSet.rrSigs, keyTagToDNSKey)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func verifyRRSetRRSigs(rrSet []dns.RR, rrSigs []*dns.RRSIG,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (
+ err error) {
+ if len(rrSet) == 0 || len(rrSigs) == 0 {
+ panic("no rrs or rrsigs")
+ }
+
+ if len(rrSigs) == 1 {
+ return verifyRRSetRRSig(rrSet, rrSigs[0], keyTagToDNSKey)
+ }
+
+ // Multiple RRSIGs for the same RRSet, sort them by algorithm preference
+ // and try each one until one succeeds. This is rather undocumented,
+ // but one signature verified should be enough to validate the RRSet,
+ // even if other signatures fail to verify successfully.
+ sortRRSIGsByAlgo(rrSigs)
+
+ errs := new(joinedErrors)
+ for _, rrSig := range rrSigs {
+ if !rrSig.ValidityPeriod(time.Now()) {
+ errs.add(fmt.Errorf("%w", ErrRRSigExpired))
+ continue
+ }
+
+ keyTag := rrSig.KeyTag
+ dnsKey, ok := keyTagToDNSKey[keyTag]
+ if !ok {
+ errs.add(fmt.Errorf("%w: in %d DNSKEY(s) for key tag %d",
+ ErrRRSigDNSKeyTag, len(keyTagToDNSKey), keyTag))
+ continue
+ }
+
+ err = rrSig.Verify(dnsKey, rrSet)
+ if err != nil {
+ errs.add(err)
+ continue
+ }
+
+ return nil
+ }
+
+ return fmt.Errorf("%d RRSIGs failed to validate the RRSet: %w",
+ len(rrSigs), errs)
+}
+
+var (
+ ErrRRSigDNSKeyTag = errors.New("DNSKEY not found")
+ ErrRRSigExpired = errors.New("RRSIG has expired")
+)
+
+func verifyRRSetRRSig(rrSet []dns.RR, rrSig *dns.RRSIG,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ if !rrSig.ValidityPeriod(time.Now()) {
+ return fmt.Errorf("%w", ErrRRSigExpired)
+ }
+
+ keyTag := rrSig.KeyTag
+ dnsKey, ok := keyTagToDNSKey[keyTag]
+ if !ok {
+ return fmt.Errorf("%w: in %d DNSKEY(s) for key tag %d",
+ ErrRRSigDNSKeyTag, len(keyTagToDNSKey), keyTag)
+ }
+
+ err = rrSig.Verify(dnsKey, rrSet)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// sortRRSIGsByAlgo sorts RRSIGs by algorithm preference.
+func sortRRSIGsByAlgo(rrSigs []*dns.RRSIG) {
+ sort.Slice(rrSigs, func(i, j int) bool {
+ return lessDNSKeyAlgorithm(rrSigs[i].Algorithm, rrSigs[j].Algorithm)
+ })
+}
+
+var (
+ ErrRRSigSignerName = errors.New("signer name is not valid")
+)
+
+// The RRSIG RR's Signer's Name field MUST be the
+// name of the zone that contains the RRset.
+func rrSigCheckSignerName(rrSig *dns.RRSIG) (err error) {
+ var validSignerNames []string
+ switch rrSig.TypeCovered {
+ case dns.TypeDS, dns.TypeCNAME, dns.TypeNSEC3:
+ validSignerNames = []string{parentName(rrSig.Hdr.Name)}
+ default:
+ // For NSEC RRs, the signer name must be the apex name which
+ // can be the owner or the parent of the owner of the RRSIG.
+ // For example:
+ // p.example.com. 3601 IN NSEC r.example.com. A RRSIG NSEC
+ // p.example.com. 3601 IN RRSIG NSEC 13 3 3601 20240111000000 20231221000000 42950 example.com. 0se..m GY..w==
+ // example.com. 3601 IN NSEC l.example.com. A NS SOA RRSIG NSEC DNSKEY
+ // example.com. 3601 IN RRSIG NSEC 13 2 3601 20240111000000 20231221000000 42950 example.com. pe..B 4V..Q==
+
+ // For other RRs, such as A, the signer name must be the owner
+ // or the parent of the owner, for example for sigok.ippacket.stream.
+ // the A record RRSIG owner is sigok.rsa2048-sha256.ippacket.stream.
+ // and signer name is rsa2048-sha256.ippacket.stream.
+ validSignerNames = []string{rrSig.Hdr.Name, parentName(rrSig.Hdr.Name)}
+ }
+
+ if isOneOf(rrSig.SignerName, validSignerNames...) {
+ return nil
+ }
+
+ quoteStrings(validSignerNames)
+ return fmt.Errorf("for %s: %w: %q should be %s",
+ rrSigToOwnerTypeCovered(rrSig), ErrRRSigSignerName,
+ rrSig.SignerName, orStrings(validSignerNames))
+}
+
+func parentName(name string) (parent string) {
+ const offset = 0
+ nextLabelStart, end := dns.NextLabel(name, offset)
+ if end {
+ // parent of 'tld.' is '.' and parent of '.' is '.'
+ return "."
+ }
+ return name[nextLabelStart:]
+}
diff --git a/internal/dnssec/rrsig_test.go b/internal/dnssec/rrsig_test.go
new file mode 100644
index 00000000..4f6a68ae
--- /dev/null
+++ b/internal/dnssec/rrsig_test.go
@@ -0,0 +1,153 @@
+package dnssec
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_sortRRSIGsByAlgo(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ rrSigs []*dns.RRSIG
+ expected []*dns.RRSIG
+ }{
+ "empty": {},
+ "single": {
+ rrSigs: []*dns.RRSIG{
+ {Algorithm: dns.RSASHA1},
+ },
+ expected: []*dns.RRSIG{
+ {Algorithm: dns.RSASHA1},
+ },
+ },
+ "multiple": {
+ rrSigs: []*dns.RRSIG{
+ {Algorithm: dns.ED25519},
+ {Algorithm: dns.RSASHA1},
+ {Algorithm: dns.ECCGOST},
+ {Algorithm: dns.RSASHA512},
+ {Algorithm: dns.ECDSAP384SHA384},
+ {Algorithm: dns.DSA},
+ },
+ expected: []*dns.RRSIG{
+ {Algorithm: dns.ED25519},
+ {Algorithm: dns.ECDSAP384SHA384},
+ {Algorithm: dns.RSASHA1},
+ {Algorithm: dns.RSASHA512},
+ {Algorithm: dns.ECCGOST},
+ {Algorithm: dns.DSA},
+ },
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ sortRRSIGsByAlgo(testCase.rrSigs)
+
+ assert.Equal(t, testCase.expected, testCase.rrSigs)
+ })
+ }
+}
+
+func Test_rrSigCheckSignerName(t *testing.T) {
+ t.Parallel()
+
+ testCases := map[string]struct {
+ rrSig *dns.RRSIG
+ errWrapped error
+ errMessage string
+ }{
+ "a_signer_is_owner": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeA,
+ SignerName: "example.com.",
+ },
+ },
+ "a_signer_is_parent": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeA,
+ SignerName: "com.",
+ },
+ },
+ "a_signer_is_invalid": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeA,
+ SignerName: ".",
+ },
+ errWrapped: ErrRRSigSignerName,
+ errMessage: `for RRSIG for owner example.com. and type A: ` +
+ `signer name is not valid: "." should be "example.com." or "com."`,
+ },
+ "ds_signer_is_parent": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeDS,
+ SignerName: "com.",
+ },
+ },
+ "ds_signer_is_owner": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeDS,
+ SignerName: "example.com.",
+ },
+ errWrapped: ErrRRSigSignerName,
+ errMessage: `for RRSIG for owner example.com. and type DS: ` +
+ `signer name is not valid: "example.com." should be "com."`,
+ },
+ "cname_signer_is_parent": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeCNAME,
+ SignerName: "com.",
+ },
+ },
+ "cname_signer_is_owner": {
+ rrSig: &dns.RRSIG{
+ Hdr: dns.RR_Header{
+ Name: "example.com.",
+ },
+ TypeCovered: dns.TypeCNAME,
+ SignerName: "example.com.",
+ },
+ errWrapped: ErrRRSigSignerName,
+ errMessage: `for RRSIG for owner example.com. and type CNAME: ` +
+ `signer name is not valid: "example.com." should be "com."`,
+ },
+ }
+
+ for name, testCase := range testCases {
+ testCase := testCase
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ err := rrSigCheckSignerName(testCase.rrSig)
+
+ assert.ErrorIs(t, err, testCase.errWrapped)
+ if testCase.errWrapped != nil {
+ assert.EqualError(t, err, testCase.errMessage)
+ }
+ })
+ }
+}
diff --git a/internal/dnssec/signeddata.go b/internal/dnssec/signeddata.go
new file mode 100644
index 00000000..78bd5747
--- /dev/null
+++ b/internal/dnssec/signeddata.go
@@ -0,0 +1,9 @@
+package dnssec
+
+type signedData struct {
+ zone string
+ // TODO do we need this class field? Maybe for caching??
+ class uint16
+ dnsKeyResponse dnssecResponse
+ dsResponse dnssecResponse
+}
diff --git a/internal/dnssec/strings.go b/internal/dnssec/strings.go
new file mode 100644
index 00000000..69d06cd9
--- /dev/null
+++ b/internal/dnssec/strings.go
@@ -0,0 +1,31 @@
+package dnssec
+
+import "fmt"
+
+func quoteStrings(elements []string) {
+ for i := range elements {
+ elements[i] = "\"" + elements[i] + "\""
+ }
+}
+
+func orStrings[T comparable](elements []T) (result string) {
+ return joinStrings(elements, "or")
+}
+
+func joinStrings[T comparable](elements []T, lastJoin string) (result string) {
+ if len(elements) == 0 {
+ return ""
+ }
+
+ result = fmt.Sprint(elements[0])
+ for i := 1; i < len(elements); i++ {
+ lastElement := i == len(elements)-1
+ if lastElement {
+ result += " " + lastJoin + " " + fmt.Sprint(elements[i])
+ continue
+ }
+ result += ", " + fmt.Sprint(elements[i])
+ }
+
+ return result
+}
diff --git a/internal/dnssec/validate.go b/internal/dnssec/validate.go
new file mode 100644
index 00000000..64f3b40b
--- /dev/null
+++ b/internal/dnssec/validate.go
@@ -0,0 +1,239 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// verify uses the zone data in the signed zone and its parent signed zones
+// to verify the DNSSEC chain of trust.
+// It starts the verification with the RRSet given as argument, and,
+// assuming a signature is valid, it walks through the slice of signed
+// zones checking the RRSIGs on the DNSKEY and DS resource record sets.
+func validateWithChain(desiredZone string, qType uint16,
+ desiredResponse dnssecResponse, chain []signedData) (err error) {
+ // Verify the root zone "."
+ rootZone := chain[0]
+
+ // Verify DNSKEY RRSet with its RRSIG and the DNSKEY matching
+ // the RRSIG key tag.
+ rootZoneKeyTagToDNSKey := makeKeyTagToDNSKey(rootZone.dnsKeyResponse.onlyAnswerRRSet())
+ err = verifyRRSetRRSigs(rootZone.dnsKeyResponse.onlyAnswerRRSet(),
+ rootZone.dnsKeyResponse.onlyAnswerRRSigs(), rootZoneKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying DNSKEY records for the root zone: %w",
+ err)
+ }
+
+ // Verify the root anchor digest against the digest of the DS
+ // calculated from the DNSKEY of the root zone matching the
+ // root anchor key tag.
+ const (
+ rootAnchorKeyTag = 20326
+ rootAnchorDigest = "E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D"
+ )
+ rootAnchor := &dns.DS{
+ Algorithm: dns.RSASHA256,
+ DigestType: dns.SHA256,
+ KeyTag: rootAnchorKeyTag,
+ Digest: rootAnchorDigest,
+ }
+ err = verifyDS(rootAnchor, rootZoneKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying the root anchor: %w", err)
+ }
+
+ wildcardName := extractWildcardExpansion(desiredResponse.answerRRSets)
+ if wildcardName != "" {
+ wildcardLabelsCount := dns.CountLabel(wildcardName)
+ chain = chain[:wildcardLabelsCount]
+ }
+
+ parentZoneInsecure := false
+ for i := 1; i < len(chain); i++ {
+ // Iterate in this order: "com.", "example.com.", "abc.example.com."
+ // Note the chain may not include the desired zone if one of its parent
+ // zone is unsigned. Checking a parent zone is indeed unsigned
+ // with DS-associated NSEC/NSEC3 RRSets also verifies the desired
+ // zone is unsigned.
+ zoneData := chain[i]
+ parentZoneData := chain[i-1]
+
+ switch {
+ case zoneData.dsResponse.isNXDomain():
+ parentKeyTagToDNSKey := makeKeyTagToDNSKey(parentZoneData.dnsKeyResponse.onlyAnswerRRSet())
+ err = validateNxDomain(zoneData.zone, zoneData.dsResponse.authorityRRSets, parentKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating NXDOMAIN DS response: %w", err)
+ }
+ // no need to continue the verification for this zone since
+ // child zones are unsigned.
+ parentZoneInsecure = true
+ case zoneData.dsResponse.isNoData():
+ parentKeyTagToDNSKey := makeKeyTagToDNSKey(parentZoneData.dnsKeyResponse.onlyAnswerRRSet())
+ err = validateNoDataDS(zoneData.zone, zoneData.dsResponse.authorityRRSets, parentKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating no data DS response: %w", err)
+ }
+
+ // no need to continue the verification for this zone since
+ // child zones are unsigned.
+ parentZoneInsecure = true
+ default: // signed zone
+ }
+
+ if parentZoneInsecure {
+ break
+ }
+
+ // Validate DNSKEY RRSet with its RRSIG and the DNSKEY matching
+ // the RRSIG key tag. Note a zone should only have a DNSKEY RRSet
+ // if it has a DS RRSet.
+ keyTagToDNSKey := makeKeyTagToDNSKey(zoneData.dnsKeyResponse.onlyAnswerRRSet())
+ err = verifyRRSetRRSigs(zoneData.dnsKeyResponse.onlyAnswerRRSet(),
+ zoneData.dnsKeyResponse.onlyAnswerRRSigs(),
+ keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating DNSKEY RRSet for zone %s: %w",
+ zoneData.zone, err)
+ }
+
+ // Validate DS RRSet with its RRSIG and the DNSKEY of its parent zone
+ // matching the RRSIG key tag.
+ parentKeyTagToDNSKey := makeKeyTagToDNSKey(parentZoneData.dnsKeyResponse.onlyAnswerRRSet())
+ err = verifyRRSetRRSigs(zoneData.dsResponse.onlyAnswerRRSet(),
+ zoneData.dsResponse.onlyAnswerRRSigs(), parentKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating DS RRSet for zone %s: %w",
+ zoneData.zone, err)
+ }
+
+ // Validate DS RRSet digests with their corresponding DNSKEYs.
+ err = verifyDSRRSet(zoneData.dsResponse.onlyAnswerRRSet(), keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying DS RRSet for zone %s: %w",
+ zoneData.zone, err)
+ }
+ }
+
+ if !desiredResponse.isSigned() && !parentZoneInsecure {
+ // The desired query returned an insecure response
+ // (unsigned answers or no NSEC/NSEC3 RRSets) and
+ // no parent zone was found to be unsigned, meaning this
+ // is bogus.
+ return fmt.Errorf("%w: desired query response is unsigned "+
+ "but no parent zone was found to be insecure", ErrBogus)
+ }
+
+ if parentZoneInsecure {
+ // Whether the desired query is signed or not, if a parent zone
+ // is insecure, the desired query is insecure.
+ // For example IN A textsecure-service.whispersystems.org. has NSEC
+ // signed by whispersystems.org., which has DNSKEYs but no DS record.
+ return nil
+ }
+
+ // From this point, the desiredResponse is signed.
+
+ // Note we validate the desired zone last since there might be a
+ // break in the chain, where there is no DNSKEY for the parent zone
+ // of the desired zone which has a DS RRSet.
+ // For example for textsecure-service.whispersystems.org.
+ var lastSecureZoneData signedData
+ for i := len(chain) - 1; i >= 0; i-- {
+ zoneData := chain[i]
+ if len(zoneData.dsResponse.onlyAnswerRRSet()) > 0 {
+ lastSecureZoneData = zoneData
+ break
+ }
+ }
+
+ lastSecureKeyTagToDNSKey := makeKeyTagToDNSKey(lastSecureZoneData.dnsKeyResponse.onlyAnswerRRSet())
+ switch {
+ case desiredResponse.isNXDomain():
+ err = validateNxDomain(desiredZone, desiredResponse.authorityRRSets,
+ lastSecureKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating negative NXDOMAIN response: %w", err)
+ }
+ return nil
+ case desiredResponse.isNoData():
+ err = validateNoData(desiredZone, qType, desiredResponse.authorityRRSets,
+ lastSecureKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("validating negative NODATA response: %w", err)
+ }
+ return nil
+ default:
+ // Verify the desired RRSets with the DNSKEY of the desired
+ // zone matching the RRSIG key tag.
+ err = verifyRRSetsRRSig(desiredResponse.answerRRSets, lastSecureKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying RRSets with RRSigs: %w", err)
+ }
+
+ if wildcardName == "" { // no wildcard expansion
+ return nil
+ }
+
+ err = verifyRRSetsRRSig(desiredResponse.authorityRRSets, lastSecureKeyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying authority section RRSets with RRSigs: %w", err)
+ }
+
+ err = validateWildcardExpansion(desiredZone, desiredResponse.authorityRRSets)
+ if err != nil {
+ return fmt.Errorf("validating wildcard expansion: %w", err)
+ }
+
+ return nil
+ }
+}
+
+// verifyDSRRSet verifies the digest of each received DS
+// is equal to the digest of the calculated DS obtained
+// from the DNSKEY (KSK) matching the received DS key tag.
+func verifyDSRRSet(dsRRSet []dns.RR,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) (err error) {
+ for _, rr := range dsRRSet {
+ ds := mustRRToDS(rr)
+ err = verifyDS(ds, keyTagToDNSKey)
+ if err != nil {
+ return fmt.Errorf("verifying DS record: %w", err)
+ }
+ }
+ return nil
+}
+
+var (
+ ErrDNSKeyNotFound = errors.New("DNSKEY resource record not found")
+ ErrDNSKeyToDS = errors.New("failed to calculate DS from DNSKEY")
+ ErrDNSKeyDSMismatch = errors.New("DS does not match DNS key")
+)
+
+func verifyDS(receivedDS *dns.DS,
+ keyTagToDNSKey map[uint16]*dns.DNSKEY) error {
+ // Note keyTagToDNSKey only contains ZSKs.
+ dnsKey, ok := keyTagToDNSKey[receivedDS.KeyTag]
+ if !ok {
+ return fmt.Errorf("for RRSIG key tag %d: %w",
+ receivedDS.KeyTag, ErrDNSKeyNotFound)
+ }
+
+ calculatedDS := dnsKey.ToDS(receivedDS.DigestType)
+ if calculatedDS == nil {
+ return fmt.Errorf("%w: for DNSKEY name %s and digest type %d",
+ ErrDNSKeyToDS, dnsKey.Header().Name, receivedDS.DigestType)
+ }
+
+ if !strings.EqualFold(receivedDS.Digest, calculatedDS.Digest) {
+ return fmt.Errorf("%w: DS record has digest %s "+
+ "but DNSKEY calculated DS has digest %s",
+ ErrDNSKeyDSMismatch, receivedDS.Digest, calculatedDS.Digest)
+ }
+
+ return nil
+}
diff --git a/internal/dnssec/wildcard.go b/internal/dnssec/wildcard.go
new file mode 100644
index 00000000..048d5368
--- /dev/null
+++ b/internal/dnssec/wildcard.go
@@ -0,0 +1,68 @@
+package dnssec
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// extractWildcardExpansion returns an empty string if no wildcard expansion
+// is found, otherwise it returns the wildcard name in the format "*.domain.tld.".
+func extractWildcardExpansion(signedRRSets []dnssecRRSet) (wildcardName string) {
+ // TODO simplify this once tested live
+ var expandedQtype uint16 // TODO remove safety check
+ for _, signedRRSet := range signedRRSets {
+ for _, rrSig := range signedRRSet.rrSigs {
+ if !isRRSigForWildcard(rrSig) {
+ continue
+ }
+
+ labels := dns.SplitDomainName(rrSig.Hdr.Name)
+ newWildcardName := dns.Fqdn("*." + strings.Join(labels[len(labels)-int(rrSig.Labels):], "."))
+ if wildcardName != "" && wildcardName != newWildcardName {
+ globalDebugLogger.Errorf("wildcard expanded multiple names: %s and %s",
+ wildcardName, newWildcardName)
+ }
+ wildcardName = newWildcardName
+
+ if expandedQtype != dns.TypeNone && expandedQtype != rrSig.TypeCovered {
+ globalDebugLogger.Errorf("wildcard expanded multiple types: %s and %s",
+ dns.TypeToString[expandedQtype], dns.TypeToString[rrSig.TypeCovered])
+ }
+ expandedQtype = rrSig.TypeCovered
+ }
+ }
+
+ return wildcardName
+}
+
+var (
+ ErrNSECxMissing = errors.New("NSEC or NSEC3 record missing")
+)
+
+// For wildcard considerations in positive responses, see:
+// - https://datatracker.ietf.org/doc/html/rfc2535#section-5.3
+// - https://datatracker.ietf.org/doc/html/rfc4035#section-5.3.4
+// - https://datatracker.ietf.org/doc/html/rfc4035#section-3.1.3.3
+func validateWildcardExpansion(expandedQname string,
+ authoritySection []dnssecRRSet) (err error) {
+ nsec3RRs, wildcard := extractNSEC3s(authoritySection)
+ if len(nsec3RRs) > 0 {
+ nsec3RRs, err = nsec3InitialChecks(nsec3RRs)
+ if err != nil {
+ return fmt.Errorf("initial NSEC3 checks: %w", err)
+ }
+ globalDebugLogger.Infof("validating wildcard expansion for %s: "+
+ "NSEC3 wildcard: %t", expandedQname, wildcard)
+ return nsec3ValidateWildcard(expandedQname, nsec3RRs)
+ }
+
+ nsecRRs := extractNSECs(authoritySection)
+ if len(nsecRRs) > 0 {
+ return nsecValidateNxDomain(expandedQname, nsecRRs)
+ }
+
+ return fmt.Errorf("%w", ErrNSECxMissing)
+}
diff --git a/internal/setup/dns.go b/internal/setup/dns.go
index 501b3578..17214990 100644
--- a/internal/setup/dns.go
+++ b/internal/setup/dns.go
@@ -6,6 +6,7 @@ import (
"github.com/qdm12/dns/v2/internal/config"
"github.com/qdm12/dns/v2/pkg/metrics/prometheus"
cachemiddleware "github.com/qdm12/dns/v2/pkg/middlewares/cache"
+ "github.com/qdm12/dns/v2/pkg/middlewares/dnssec"
filtermiddleware "github.com/qdm12/dns/v2/pkg/middlewares/filter"
"github.com/qdm12/dns/v2/pkg/middlewares/localdns"
"github.com/qdm12/log"
@@ -19,14 +20,15 @@ type Service interface {
func DNS(userSettings config.Settings, ipv6Support bool, //nolint:ireturn
cache Cache, filter Filter, loggerConstructor LoggerConstructor,
- promRegistry PrometheusRegistry) (server Service, err error) {
+ promRegistry PrometheusRegistry, dnssecEnabled bool) (
+ server Service, err error) {
commonPrometheus := prometheus.Settings{
Prefix: *userSettings.Metrics.Prometheus.Subsystem,
Registry: promRegistry,
}
middlewares, err := setupMiddlewares(userSettings, cache,
- filter, loggerConstructor, commonPrometheus)
+ filter, loggerConstructor, commonPrometheus, dnssecEnabled)
if err != nil {
return nil, fmt.Errorf("setting up middlewares: %w", err)
}
@@ -56,7 +58,8 @@ func DNS(userSettings config.Settings, ipv6Support bool, //nolint:ireturn
}
func setupMiddlewares(userSettings config.Settings, cache Cache,
- filter Filter, loggerConstructor log.ChildConstructor, commonPrometheus prometheus.Settings) (
+ filter Filter, loggerConstructor log.ChildConstructor,
+ commonPrometheus prometheus.Settings, dnssecEnabled bool) (
middlewares []Middleware, err error) {
cacheMiddleware, err := cachemiddleware.New(cachemiddleware.Settings{Cache: cache})
if err != nil {
@@ -64,6 +67,17 @@ func setupMiddlewares(userSettings config.Settings, cache Cache,
}
middlewares = append(middlewares, cacheMiddleware)
+ if dnssecEnabled {
+ settings := dnssec.Settings{
+ Logger: loggerConstructor.New(log.SetComponent("DNSSEC")),
+ }
+ dnssecMiddleware, err := dnssec.New(settings)
+ if err != nil {
+ return nil, fmt.Errorf("creating DNSSEC middleware: %w", err)
+ }
+ middlewares = append(middlewares, dnssecMiddleware)
+ }
+
if *userSettings.LocalDNS.Enabled {
localDNSMiddleware, err := localdns.New(localdns.Settings{
Resolvers: userSettings.LocalDNS.Resolvers, // possibly auto-detected
diff --git a/pkg/middlewares/dnssec/handler.go b/pkg/middlewares/dnssec/handler.go
new file mode 100644
index 00000000..f06c1eb2
--- /dev/null
+++ b/pkg/middlewares/dnssec/handler.go
@@ -0,0 +1,30 @@
+package dnssec
+
+import (
+ "github.com/miekg/dns"
+ "github.com/qdm12/dns/v2/internal/dnssec"
+)
+
+type handler struct {
+ // Injected from middleware
+ logger Logger
+ next dns.Handler
+}
+
+func newHandler(logger Logger, next dns.Handler) *handler {
+ return &handler{
+ logger: logger,
+ next: next,
+ }
+}
+
+func (h *handler) ServeDNS(w dns.ResponseWriter, request *dns.Msg) {
+ response, err := dnssec.Validate(request, h.next)
+ if err != nil {
+ h.logger.Warn(err.Error())
+ response = new(dns.Msg)
+ response.SetRcode(request, dns.RcodeServerFailure)
+ }
+
+ _ = w.WriteMsg(response)
+}
diff --git a/pkg/middlewares/dnssec/interfaces.go b/pkg/middlewares/dnssec/interfaces.go
new file mode 100644
index 00000000..296bedb1
--- /dev/null
+++ b/pkg/middlewares/dnssec/interfaces.go
@@ -0,0 +1,5 @@
+package dnssec
+
+type Logger interface {
+ Warn(message string)
+}
diff --git a/pkg/middlewares/dnssec/middleware.go b/pkg/middlewares/dnssec/middleware.go
new file mode 100644
index 00000000..68865bf1
--- /dev/null
+++ b/pkg/middlewares/dnssec/middleware.go
@@ -0,0 +1,47 @@
+package dnssec
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "github.com/miekg/dns"
+)
+
+// Middleware implements a DNSSEC validator.
+type Middleware struct {
+ settings Settings
+ wrapping atomic.Bool
+}
+
+func New(settings Settings) (middleware *Middleware, err error) {
+ settings.SetDefaults()
+
+ err = settings.Validate()
+ if err != nil {
+ return nil, fmt.Errorf("validating settings: %w", err)
+ }
+
+ return &Middleware{
+ settings: settings,
+ }, nil
+}
+
+func (m *Middleware) String() string {
+ return "DNSSEC validator"
+}
+
+// Wrap wraps the DNS handler with the middleware.
+func (m *Middleware) Wrap(next dns.Handler) dns.Handler { //nolint:ireturn
+ previousWrapping := m.wrapping.Swap(true)
+ if previousWrapping {
+ panic("DNSSEC middleware cannot wrap more than once")
+ }
+
+ handler := newHandler(m.settings.Logger, next)
+ return handler
+}
+
+// Stop is a no-op for the DNSSEC middleware.
+func (m *Middleware) Stop() (err error) {
+ return nil
+}
diff --git a/pkg/middlewares/dnssec/settings.go b/pkg/middlewares/dnssec/settings.go
new file mode 100644
index 00000000..66cf4e3b
--- /dev/null
+++ b/pkg/middlewares/dnssec/settings.go
@@ -0,0 +1,28 @@
+package dnssec
+
+import (
+ "github.com/qdm12/dns/v2/pkg/log/noop"
+ "github.com/qdm12/gosettings"
+ "github.com/qdm12/gotree"
+)
+
+type Settings struct {
+ // Logger is the logger to use.
+ // It defaults to a No-op implementation.
+ Logger Logger
+}
+
+func (s *Settings) SetDefaults() {
+ s.Logger = gosettings.DefaultComparable[Logger](s.Logger, noop.New())
+}
+
+func (s *Settings) Validate() error { return nil }
+
+func (s *Settings) String() string {
+ return s.ToLinesNode().String()
+}
+
+func (s *Settings) ToLinesNode() (node *gotree.Node) {
+ node = gotree.New("DNSSEC settings:")
+ return node
+}