Skip to content

Commit

Permalink
addrinfo: use default resolver (#70)
Browse files Browse the repository at this point in the history
This PR modifies `wasirun` to use `net.DefaultResolver` so we can
configure the DNS servers for `SockAddressInfo` and propagate the
context for asynchronous cancellation.

---------

Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 25, 2023
1 parent f05f093 commit d1f5b98
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 41 deletions.
30 changes: 27 additions & 3 deletions cmd/wasirun/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"flag"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"os"
Expand Down Expand Up @@ -33,20 +34,23 @@ OPTIONS:
--dir <DIR>
Grant access to the specified host directory
--listen <ADDR>
--listen <ADDR:PORT>
Grant access to a socket listening on the specified address
--dial <ADDR>
--dial <ADDR:PORT>
Grant access to a socket connected to the specified address
--dns-server <ADDR:PORT>
Sets the address of the DNS server to use for name resolution
--env <NAME=VAL>
Pass an environment variable to the module
--sockets <NAME>
Enable a sockets extension, either {none, auto, path_open,
wasmedgev1, wasmedgev2}
--pprof-addr <ADDR>
--pprof-addr <ADDR:PORT>
Start a pprof server listening on the specified address
--trace
Expand All @@ -68,6 +72,7 @@ var (
dirs stringList
listens stringList
dials stringList
dnsServer string
socketExt string
pprofAddr string
trace bool
Expand All @@ -83,6 +88,7 @@ func main() {
flagSet.Var(&dirs, "dir", "")
flagSet.Var(&listens, "listen", "")
flagSet.Var(&dials, "dial", "")
flagSet.StringVar(&dnsServer, "dns-server", "", "")
flagSet.StringVar(&socketExt, "sockets", "auto", "")
flagSet.StringVar(&pprofAddr, "pprof-addr", "", "")
flagSet.BoolVar(&trace, "trace", false, "")
Expand All @@ -106,6 +112,24 @@ func main() {
os.Exit(1)
}

if dnsServer != "" {
_, dnsServerPort, _ := net.SplitHostPort(dnsServer)
net.DefaultResolver.PreferGo = true
net.DefaultResolver.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
if dnsServerPort != "" {
address = dnsServer
} else {
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, net.InvalidAddrError(address)
}
address = net.JoinHostPort(dnsServer, port)
}
return d.DialContext(ctx, network, address)
}
}

if err := run(args[0], args[1:]); err != nil {
if exitErr, ok := err.(*sys.ExitError); ok {
os.Exit(int(exitErr.ExitCode()))
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ require (
golang.org/x/sys v0.8.0
)

require golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
require golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
10 changes: 4 additions & 6 deletions imports/wasi_snapshot_preview1/wasmedge.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,8 @@ func (m *Module) WasmEdgeSockAddrInfo(ctx context.Context, name String, service
// but then doesn't set ai_canonnamelen... Argh.
mem := resPtrPtr.Memory()
resPtr := resPtrPtr.Load()
results := m.addrinfo[:n]
count := 0
for {
for _, addrinfo := range m.addrinfo[:n] {
res := resPtr.Load()
if res.Address == 0 {
return Errno(wasi.EFAULT)
Expand All @@ -304,7 +303,7 @@ func (m *Module) WasmEdgeSockAddrInfo(ctx context.Context, name String, service
if !ok {
return Errno(wasi.EFAULT)
}
switch addr := results[0].Address.(type) {
switch addr := addrinfo.Address.(type) {
case *wasi.Inet4Address:
if len(addrData) < 6 {
return Errno(wasi.EFAULT)
Expand All @@ -325,11 +324,10 @@ func (m *Module) WasmEdgeSockAddrInfo(ctx context.Context, name String, service
res.CanonicalNameLength = 0 // Not yet supported
resPtr.Store(res)
count++
results = results[1:]
if res.Next == 0 || len(results) == 0 {
if res.Next == 0 {
break
}
resPtr = Ptr[wasmEdgeAddressInfo](resPtr.Memory(), res.Next)
resPtr = Ptr[wasmEdgeAddressInfo](mem, res.Next)
}
resLengthPtr.Store(Uint32(count))
return Errno(wasi.ESUCCESS)
Expand Down
67 changes: 36 additions & 31 deletions systems/unix/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/stealthrocket/wasi-go"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix"
)

Expand Down Expand Up @@ -967,7 +966,7 @@ func (s *System) SockAddressInfo(ctx context.Context, name, service string, hint
if hints.Flags.Has(wasi.NumericService) {
port, err = strconv.Atoi(service)
} else {
port, err = net.LookupPort(network, service)
port, err = net.DefaultResolver.LookupPort(ctx, network, service)
}
if err != nil || port < 0 || port > 65535 {
return 0, wasi.EINVAL // EAI_NONAME / EAI_SERVICE
Expand All @@ -989,55 +988,61 @@ func (s *System) SockAddressInfo(ctx context.Context, name, service string, hint
ip = net.IPv4zero
}
}
if ip != nil {
results = results[:1]
results[0] = wasi.AddressInfo{}

makeAddressInfo := func(ip net.IP, port int) wasi.AddressInfo {
addrInfo := wasi.AddressInfo{
Flags: hints.Flags,
SocketType: hints.SocketType,
Protocol: hints.Protocol,
}
if ipv4 := ip.To4(); ipv4 != nil {
inet4Addr := &wasi.Inet4Address{Port: port}
copy(inet4Addr.Addr[:], ipv4)
results[0].Address = inet4Addr
addrInfo.Family = wasi.InetFamily
addrInfo.Address = inet4Addr
} else {
inet6Addr := &wasi.Inet6Address{Port: port}
copy(inet6Addr.Addr[:], ip)
results[0].Address = inet6Addr
addrInfo.Family = wasi.Inet6Family
addrInfo.Address = inet6Addr
}
return addrInfo
}

if ip != nil {
results[0] = makeAddressInfo(ip, port)
return 1, wasi.ESUCCESS
}

ips, err := net.LookupIP(name)
// LookupIP requires the network to be one of "ip", "ip4", or "ip6".
switch network {
case "tcp", "udp":
network = "ip"
case "tcp4", "udp4":
network = "ip4"
case "tcp6", "udp6":
network = "ip6"
}

ips, err := net.DefaultResolver.LookupIP(ctx, network, name)
if err != nil {
return 0, wasi.ECANCELED // TODO: better errors on name resolution failure
}

addrs := make([]wasi.AddressInfo, 0, 16)
addrs4 := make([]wasi.AddressInfo, 0, 8)
addrs6 := make([]wasi.AddressInfo, 0, 8)

for _, ip := range ips {
var addr wasi.AddressInfo
if ipv4 := ip.To4(); ipv4 != nil {
if hints.Family == wasi.Inet6Family {
continue
}
inet4Addr := wasi.Inet4Address{Port: port}
copy(inet4Addr.Addr[:], ip)
addr.Family = wasi.InetFamily
addr.Address = &inet4Addr
if ip.To4() != nil {
addrs4 = append(addrs4, makeAddressInfo(ip, port))
} else {
if hints.Family == wasi.InetFamily {
continue
}
inet6Addr := wasi.Inet6Address{Port: port}
copy(inet6Addr.Addr[:], ip)
addr.Family = wasi.Inet6Family
addr.Address = &inet6Addr
addrs6 = append(addrs6, makeAddressInfo(ip, port))
}
addrs = append(addrs, addr)
}

slices.SortStableFunc(addrs, func(a1, a2 wasi.AddressInfo) bool {
return a1.Family < a2.Family
})

return copy(results, addrs), wasi.ESUCCESS
n := copy(results[0:], addrs4)
n += copy(results[n:], addrs6)
return n, wasi.ESUCCESS
}

func (s *System) Close(ctx context.Context) error {
Expand Down

0 comments on commit d1f5b98

Please sign in to comment.