Skip to content

Commit

Permalink
socket: clean up internal control context use
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Layher <[email protected]>
  • Loading branch information
mdlayher committed Aug 30, 2023
1 parent 749bf3b commit 8e65586
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 34 deletions.
39 changes: 14 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,7 @@ func (c *Conn) Accept(ctx context.Context, flags int) (*Conn, unix.Sockaddr, err

// Bind wraps bind(2).
func (c *Conn) Bind(sa unix.Sockaddr) error {
return c.control(context.Background(), "bind", func(fd int) error {
return unix.Bind(fd, sa)
})
return c.control("bind", func(fd int) error { return unix.Bind(fd, sa) })
}

// Connect wraps connect(2). In order to verify that the underlying socket is
Expand Down Expand Up @@ -530,40 +528,38 @@ func (c *Conn) Connect(ctx context.Context, sa unix.Sockaddr) (unix.Sockaddr, er

// Getsockname wraps getsockname(2).
func (c *Conn) Getsockname() (unix.Sockaddr, error) {
return controlT(c, context.Background(), "getsockname", unix.Getsockname)
return controlT(c, "getsockname", unix.Getsockname)
}

// Getpeername wraps getpeername(2).
func (c *Conn) Getpeername() (unix.Sockaddr, error) {
return controlT(c, context.Background(), "getpeername", unix.Getpeername)
return controlT(c, "getpeername", unix.Getpeername)
}

// GetsockoptICMPv6Filter wraps getsockopt(2) for *unix.ICMPv6Filter values.
func (c *Conn) GetsockoptICMPv6Filter(level, opt int) (*unix.ICMPv6Filter, error) {
return controlT(c, context.Background(), "getsockopt", func(fd int) (*unix.ICMPv6Filter, error) {
return controlT(c, "getsockopt", func(fd int) (*unix.ICMPv6Filter, error) {
return unix.GetsockoptICMPv6Filter(fd, level, opt)
})
}

// GetsockoptInt wraps getsockopt(2) for integer values.
func (c *Conn) GetsockoptInt(level, opt int) (int, error) {
return controlT(c, context.Background(), "getsockopt", func(fd int) (int, error) {
return controlT(c, "getsockopt", func(fd int) (int, error) {
return unix.GetsockoptInt(fd, level, opt)
})
}

// GetsockoptString wraps getsockopt(2) for string values.
func (c *Conn) GetsockoptString(level, opt int) (string, error) {
return controlT(c, context.Background(), "getsockopt", func(fd int) (string, error) {
return controlT(c, "getsockopt", func(fd int) (string, error) {
return unix.GetsockoptString(fd, level, opt)
})
}

// Listen wraps listen(2).
func (c *Conn) Listen(n int) error {
return c.control(context.Background(), "listen", func(fd int) error {
return unix.Listen(fd, n)
})
return c.control("listen", func(fd int) error { return unix.Listen(fd, n) })
}

// Recvmsg wraps recvmsg(2).
Expand Down Expand Up @@ -618,30 +614,28 @@ func (c *Conn) Sendto(ctx context.Context, p []byte, flags int, to unix.Sockaddr

// SetsockoptICMPv6Filter wraps setsockopt(2) for *unix.ICMPv6Filter values.
func (c *Conn) SetsockoptICMPv6Filter(level, opt int, filter *unix.ICMPv6Filter) error {
return c.control(context.Background(), "setsockopt", func(fd int) error {
return c.control("setsockopt", func(fd int) error {
return unix.SetsockoptICMPv6Filter(fd, level, opt, filter)
})
}

// SetsockoptInt wraps setsockopt(2) for integer values.
func (c *Conn) SetsockoptInt(level, opt, value int) error {
return c.control(context.Background(), "setsockopt", func(fd int) error {
return c.control("setsockopt", func(fd int) error {
return unix.SetsockoptInt(fd, level, opt, value)
})
}

// SetsockoptString wraps setsockopt(2) for string values.
func (c *Conn) SetsockoptString(level, opt int, value string) error {
return c.control(context.Background(), "setsockopt", func(fd int) error {
return c.control("setsockopt", func(fd int) error {
return unix.SetsockoptString(fd, level, opt, value)
})
}

// Shutdown wraps shutdown(2).
func (c *Conn) Shutdown(how int) error {
return c.control(context.Background(), "shutdown", func(fd int) error {
return unix.Shutdown(fd, how)
})
return c.control("shutdown", func(fd int) error { return unix.Shutdown(fd, how) })
}

// Conn low-level read/write/control functions. These functions mirror the
Expand Down Expand Up @@ -830,16 +824,16 @@ func rwT[T any](c *Conn, rw rwContext[T]) (T, error) {
}

// control executes Conn.control for op using the input function.
func (c *Conn) control(ctx context.Context, op string, f func(fd int) error) error {
_, err := controlT(c, ctx, op, func(fd int) (struct{}, error) {
func (c *Conn) control(op string, f func(fd int) error) error {
_, err := controlT(c, op, func(fd int) (struct{}, error) {
return struct{}{}, f(fd)
})
return err
}

// controlT executes c.rc.Control for op using the input function, returning a
// newly allocated result T.
func controlT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) {
func controlT[T any](c *Conn, op string, f func(fd int) (T, error)) (T, error) {
if atomic.LoadUint32(&c.closed) != 0 {
// If the file descriptor is already closed, do nothing.
return *new(T), os.NewSyscallError(op, unix.EBADF)
Expand All @@ -857,11 +851,6 @@ func controlT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T,
// The last values for t and err are captured outside of the closure for
// use when the loop breaks.
for {
if err = ctx.Err(); err != nil {
// Early exit due to context cancel.
return
}

t, err = f(int(fd))
if ready(err) {
return
Expand Down
18 changes: 9 additions & 9 deletions conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// IoctlKCMClone wraps ioctl(2) for unix.KCMClone values, but returns a Conn
// rather than a raw file descriptor.
func (c *Conn) IoctlKCMClone() (*Conn, error) {
info, err := controlT(c, context.Background(), "ioctl", unix.IoctlKCMClone)
info, err := controlT(c, "ioctl", unix.IoctlKCMClone)
if err != nil {
return nil, err
}
Expand All @@ -26,22 +26,22 @@ func (c *Conn) IoctlKCMClone() (*Conn, error) {

// IoctlKCMAttach wraps ioctl(2) for unix.KCMAttach values.
func (c *Conn) IoctlKCMAttach(info unix.KCMAttach) error {
return c.control(context.Background(), "ioctl", func(fd int) error {
return c.control("ioctl", func(fd int) error {
return unix.IoctlKCMAttach(fd, info)
})
}

// IoctlKCMUnattach wraps ioctl(2) for unix.KCMUnattach values.
func (c *Conn) IoctlKCMUnattach(info unix.KCMUnattach) error {
return c.control(context.Background(), "ioctl", func(fd int) error {
return c.control("ioctl", func(fd int) error {
return unix.IoctlKCMUnattach(fd, info)
})
}

// PidfdGetfd wraps pidfd_getfd(2) for a Conn which wraps a pidfd, but returns a
// Conn rather than a raw file descriptor.
func (c *Conn) PidfdGetfd(targetFD, flags int) (*Conn, error) {
outFD, err := controlT(c, context.Background(), "pidfd_getfd", func(fd int) (int, error) {
outFD, err := controlT(c, "pidfd_getfd", func(fd int) (int, error) {
return unix.PidfdGetfd(fd, targetFD, flags)
})
if err != nil {
Expand All @@ -55,7 +55,7 @@ func (c *Conn) PidfdGetfd(targetFD, flags int) (*Conn, error) {
// PidfdSendSignal wraps pidfd_send_signal(2) for a Conn which wraps a Linux
// pidfd.
func (c *Conn) PidfdSendSignal(sig unix.Signal, info *unix.Siginfo, flags int) error {
return c.control(context.Background(), "pidfd_send_signal", func(fd int) error {
return c.control("pidfd_send_signal", func(fd int) error {
return unix.PidfdSendSignal(fd, sig, info, flags)
})
}
Expand Down Expand Up @@ -84,28 +84,28 @@ func (c *Conn) RemoveBPF() error {

// SetsockoptPacketMreq wraps setsockopt(2) for unix.PacketMreq values.
func (c *Conn) SetsockoptPacketMreq(level, opt int, mreq *unix.PacketMreq) error {
return c.control(context.Background(), "setsockopt", func(fd int) error {
return c.control("setsockopt", func(fd int) error {
return unix.SetsockoptPacketMreq(fd, level, opt, mreq)
})
}

// SetsockoptSockFprog wraps setsockopt(2) for unix.SockFprog values.
func (c *Conn) SetsockoptSockFprog(level, opt int, fprog *unix.SockFprog) error {
return c.control(context.Background(), "setsockopt", func(fd int) error {
return c.control("setsockopt", func(fd int) error {
return unix.SetsockoptSockFprog(fd, level, opt, fprog)
})
}

// GetsockoptTpacketStats wraps getsockopt(2) for unix.TpacketStats values.
func (c *Conn) GetsockoptTpacketStats(level, name int) (*unix.TpacketStats, error) {
return controlT(c, context.Background(), "getsockopt", func(fd int) (*unix.TpacketStats, error) {
return controlT(c, "getsockopt", func(fd int) (*unix.TpacketStats, error) {
return unix.GetsockoptTpacketStats(fd, level, name)
})
}

// GetsockoptTpacketStatsV3 wraps getsockopt(2) for unix.TpacketStatsV3 values.
func (c *Conn) GetsockoptTpacketStatsV3(level, name int) (*unix.TpacketStatsV3, error) {
return controlT(c, context.Background(), "getsockopt", func(fd int) (*unix.TpacketStatsV3, error) {
return controlT(c, "getsockopt", func(fd int) (*unix.TpacketStatsV3, error) {
return unix.GetsockoptTpacketStatsV3(fd, level, name)
})
}
Expand Down

0 comments on commit 8e65586

Please sign in to comment.