diff --git a/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go new file mode 100644 index 0000000000..a2ddf14a88 --- /dev/null +++ b/x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go @@ -0,0 +1,116 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package connstring + +import ( + "net" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/dns" +) + +func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) { + newTestParser := func(record string) *parser { + return &parser{&dns.Resolver{ + LookupSRV: func(_, _, _ string) (string, []*net.SRV, error) { + return "", []*net.SRV{ + { + Target: record, + Port: 27017, + }, + }, nil + }, + LookupTXT: func(string) ([]string, error) { + return nil, nil + }, + }} + } + + t.Run("1. Allow SRVs with fewer than 3 . separated parts", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"test_1.localhost", "mongodb+srv://localhost"}, + {"test_1.mongo.local", "mongodb+srv://mongo.local"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.NoError(t, err, "expected no URI parsing error, got %v", err) + }) + } + }) + t.Run("2. Throw when return address does not end with SRV domain", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"localhost.mongodb", "mongodb+srv://localhost"}, + {"test_1.evil.local", "mongodb+srv://mongo.local"}, + {"blogs.evil.com", "mongodb+srv://blogs.mongodb.com"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain") + }) + } + }) + t.Run("3. Throw when return address is identical to SRV hostname", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"localhost", "mongodb+srv://localhost"}, + {"mongo.local", "mongodb+srv://mongo.local"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "DNS name must contain at least") + }) + } + }) + t.Run("4. Throw when return address does not contain . separating shared part of domain", func(t *testing.T) { + t.Parallel() + + cases := []struct { + record string + uri string + }{ + {"test_1.cluster_1localhost", "mongodb+srv://localhost"}, + {"test_1.my_hostmongo.local", "mongodb+srv://mongo.local"}, + {"cluster.testmongodb.com", "mongodb+srv://blogs.mongodb.com"}, + } + for _, c := range cases { + c := c + t.Run(c.uri, func(t *testing.T) { + t.Parallel() + + _, err := newTestParser(c.record).parse(c.uri) + assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain") + }) + } + }) +} diff --git a/x/mongo/driver/dns/dns.go b/x/mongo/driver/dns/dns.go index 9334d493ed..4524af2794 100644 --- a/x/mongo/driver/dns/dns.go +++ b/x/mongo/driver/dns/dns.go @@ -113,15 +113,18 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr b func validateSRVResult(recordFromSRV, inputHostName string) error { separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".") separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".") - if len(separatedRecord) < 2 { - return errors.New("DNS name must contain at least 2 labels") + if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l { + return fmt.Errorf("DNS name must contain at least %d labels", l+1) } if len(separatedRecord) < len(separatedInputDomain) { return errors.New("Domain suffix from SRV record not matched input domain") } - inputDomainSuffix := separatedInputDomain[1:] - domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1) + inputDomainSuffix := separatedInputDomain + if len(inputDomainSuffix) > 2 { + inputDomainSuffix = inputDomainSuffix[1:] + } + domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix) recordDomainSuffix := separatedRecord[domainSuffixOffset:] for ix, label := range inputDomainSuffix {