Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add -4 and -6 flags #1984

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ issues:
text: 'dnsTimeout is a global variable'
- path: challenge/dns01/nameserver_test.go
text: 'findXByFqdnTestCases is a global variable'
- path: challenge/dns01/network.go
text: 'currentNetworkStack is a global variable'
- path: challenge/http01/domain_matcher.go
text: 'string `Host` has \d occurrences, make it a constant'
- path: challenge/http01/domain_matcher.go
Expand Down
14 changes: 10 additions & 4 deletions challenge/dns01/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {

func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
network := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
r, _, err := tcp.Exchange(m, ns)
if err != nil {
return r, &DNSError{Message: "DNS call error", MsgIn: m, NS: ns, Err: err}
Expand All @@ -274,11 +275,16 @@ func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
return r, nil
}

udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
udpNetwork := currentNetworkStack.Network("udp")
udp := &dns.Client{Net: udpNetwork, Timeout: dnsTimeout}
r, _, err := udp.Exchange(m, ns)

if r != nil && r.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// We can encounter a net.OpError if the nameserver is not listening
// on UDP at all, i.e. net.Dial could not make a connection.
var opErr *net.OpError
if (r != nil && r.Truncated) || errors.As(err, &opErr) {
tcpNetwork := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: tcpNetwork, Timeout: dnsTimeout}
// If the TCP request succeeds, the "err" will reset to nil
r, _, err = tcp.Exchange(m, ns)
}
Expand Down
136 changes: 134 additions & 2 deletions challenge/dns01/nameserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,134 @@ package dns01

import (
"errors"
"net"
"sort"
"sync"
"testing"
"time"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func testDNSHandler(writer dns.ResponseWriter, reply *dns.Msg) {
msg := dns.Msg{}
msg.SetReply(reply)

if reply.Question[0].Qtype == dns.TypeA {
msg.Authoritative = true
domain := msg.Question[0].Name
msg.Answer = append(
msg.Answer,
&dns.A{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 60,
},
A: net.IPv4(127, 0, 0, 1),
},
)
}

_ = writer.WriteMsg(&msg)
}

// getTestNameserver constructs a new DNS server on a local address, or set
// of addresses, that responds to an `A` query for `example.com`.
func getTestNameserver(t *testing.T, network string) *dns.Server {
t.Helper()
server := &dns.Server{
Handler: dns.HandlerFunc(testDNSHandler),
Net: network,
}
switch network {
case "tcp", "udp":
server.Addr = "0.0.0.0:0"
case "tcp4", "udp4":
server.Addr = "127.0.0.1:0"
case "tcp6", "udp6":
server.Addr = "[::1]:0"
}

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

go func() { _ = server.ListenAndServe() }()

waitLock.Lock()
return server
}

func startTestNameserver(t *testing.T, stack networkStack, proto string) (shutdown func(), addr string) {
t.Helper()
currentNetworkStack = stack
srv := getTestNameserver(t, currentNetworkStack.Network(proto))

shutdown = func() { _ = srv.Shutdown() }
if proto == "tcp" {
addr = srv.Listener.Addr().String()
} else {
addr = srv.PacketConn.LocalAddr().String()
}
return
}

func TestSendDNSQuery(t *testing.T) {
currentNameservers := recursiveNameservers

t.Cleanup(func() {
recursiveNameservers = currentNameservers
currentNetworkStack = dualStack
})

t.Run("does udp4 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv4only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String())
})

t.Run("does udp6 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv6only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String())
})

t.Run("does tcp4 and tcp6", func(t *testing.T) {
stop, addr := startTestNameserver(t, dualStack, "tcp")
host, port, _ := net.SplitHostPort(addr)
defer stop()
t.Logf("### port: %s", port)

addr6 := net.JoinHostPort(host, port)
recursiveNameservers = ParseNameservers([]string{addr6})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr6)
require.NoError(t, queryError)
assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String())

addr4 := net.JoinHostPort("127.0.0.1", port)
recursiveNameservers = ParseNameservers([]string{addr4})
msg = createDNSMsg("example.com.", dns.TypeA, true)
result, queryError = sendDNSQuery(msg, addr4)
require.NoError(t, queryError)
assert.Equal(t, "127.0.0.1", result.Answer[0].(*dns.A).A.String())
})
}

func TestLookupNameserversOK(t *testing.T) {
testCases := []struct {
fqdn string
Expand Down Expand Up @@ -75,6 +195,7 @@ var findXByFqdnTestCases = []struct {
primaryNs string
nameservers []string
expectedError string
timeout time.Duration
}{
{
desc: "domain is a CNAME",
Expand Down Expand Up @@ -117,14 +238,18 @@ var findXByFqdnTestCases = []struct {
zone: "google.com.",
primaryNs: "ns1.google.com.",
nameservers: []string{":7053", ":8053", "8.8.8.8:53"},
timeout: 500 * time.Millisecond,
},
{
desc: "only non-existent nameservers",
fqdn: "mail.google.com.",
zone: "google.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// use only the start of the message because the port changes with each call: 127.0.0.1:XXXXX->127.0.0.1:7053.
expectedError: "[fqdn=mail.google.com.] could not find the start of authority for 'mail.google.com.': DNS call error: read udp ",
// NOTE: On Windows, net.DialContext finds a way down to the ContectEx syscall.
// There a fault is marked as "connectex", not "connect", see
// https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/net/fd_windows.go;l=112
expectedError: "could not find the start of authority for 'mail.google.com.':",
timeout: 500 * time.Millisecond,
},
{
desc: "no nameservers",
Expand Down Expand Up @@ -155,6 +280,11 @@ func TestFindZoneByFqdnCustom(t *testing.T) {
func TestFindPrimaryNsByFqdnCustom(t *testing.T) {
for _, test := range findXByFqdnTestCases {
t.Run(test.desc, func(t *testing.T) {
origTimeout := dnsTimeout
if test.timeout > 0 {
dnsTimeout = test.timeout
}

ClearFqdnCache()

ns, err := FindPrimaryNsByFqdnCustom(test.fqdn, test.nameservers)
Expand All @@ -165,6 +295,8 @@ func TestFindPrimaryNsByFqdnCustom(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, test.primaryNs, ns)
}

dnsTimeout = origTimeout
})
}
}
Expand Down
41 changes: 41 additions & 0 deletions challenge/dns01/network.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package dns01

// networkStack is used to indicate which IP stack should be used for DNS queries.
type networkStack int

const (
dualStack networkStack = iota
ipv4only
ipv6only
)

// currentNetworkStack is used to define which IP stack will be used. The default is
// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version.
var currentNetworkStack = dualStack

// Network interprets the NetworkStack setting in relation to the desired
// protocol. The proto value should be either "udp" or "tcp".
func (s networkStack) Network(proto string) string {
// The DNS client passes whatever value is set in (*dns.Client).Net to
// the [net.Dialer](https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141).
// And the net.Dialer accepts strings such as "udp4" or "tcp6"
// (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182).
switch s {
case ipv4only:
return proto + "4"
case ipv6only:
return proto + "6"
default:
return proto
}
}

// SetIPv4Only forces DNS queries to only happen over the IPv4 stack.
func SetIPv4Only() { currentNetworkStack = ipv4only }

// SetIPv6Only forces DNS queries to only happen over the IPv6 stack.
func SetIPv6Only() { currentNetworkStack = ipv6only }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use.
func SetDualStack() { currentNetworkStack = dualStack }
22 changes: 22 additions & 0 deletions challenge/http01/http_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer
return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv4Only() { s.setTCPStack("tcp4") }

// SetIPv6Only starts the challenge server on an IPv6 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv6Only() { s.setTCPStack("tcp6") }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetDualStack() { s.setTCPStack("tcp") }

func (s *ProviderServer) setTCPStack(network string) {
if s.network != "unix" {
s.network = network
}
}

// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error
Expand Down
17 changes: 17 additions & 0 deletions challenge/http01/http_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestProviderServer_GetAddress(t *testing.T) {
testCases := []struct {
desc string
server *ProviderServer
network func(server *ProviderServer)
expected string
}{
{
Expand All @@ -49,6 +50,18 @@ func TestProviderServer_GetAddress(t *testing.T) {
server: NewProviderServer("localhost", "8080"),
expected: "localhost:8080",
},
{
desc: "TCP4 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv4Only() },
expected: "localhost:8080",
},
{
desc: "TCP6 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv6Only() },
expected: "localhost:8080",
},
{
desc: "UDS socket",
server: NewUnixProviderServer(sock, fs.ModeSocket|0o666),
Expand All @@ -60,6 +73,10 @@ func TestProviderServer_GetAddress(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

if test.network != nil {
test.network(test.server)
}

address := test.server.GetAddress()
assert.Equal(t, test.expected, address)
})
Expand Down
18 changes: 16 additions & 2 deletions challenge/tlsalpn01/tls_alpn_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,30 @@ const (
type ProviderServer struct {
iface string
port string
network string
listener net.Listener
}

// NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewProviderServer(iface, port string) *ProviderServer {
return &ProviderServer{iface: iface, port: port}
if port == "" {
port = defaultTLSPort
}
return &ProviderServer{iface: iface, port: port, network: "tcp"}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
func (s *ProviderServer) SetIPv4Only() { s.network = "tcp4" }

// SetIPv6Only starts the challenge server on an IPv6 address.
func (s *ProviderServer) SetIPv6Only() { s.network = "tcp6" }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
func (s *ProviderServer) SetDualStack() { s.network = "tcp" }

func (s *ProviderServer) GetAddress() string {
return net.JoinHostPort(s.iface, s.port)
}
Expand Down Expand Up @@ -65,7 +79,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
tlsConf.NextProtos = []string{ACMETLS1Protocol}

// Create the listener with the created tls.Config.
s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf)
s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf)
if err != nil {
return fmt.Errorf("could not start HTTPS server for challenge: %w", err)
}
Expand Down
Loading