Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Unix forwarding server implementations #196

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) {

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
l := newLocalTCPListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
Expand Down
2 changes: 2 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Server struct {
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding ([email protected]), denies all if nil
ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding ([email protected]), denies all if nil
ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions

Expand Down
4 changes: 2 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) {
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
testBytes := []byte("Hello world\n")
s := &Server{
Handler: func(s Session) {
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) {
}

func TestServerClose(t *testing.T) {
l := newLocalListener()
l := newLocalTCPListener()
s := &Server{
Handler: func(s Session) {
time.Sleep(5 * time.Second)
Expand Down
19 changes: 15 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error {
return e
}
srv.ChannelHandlers = map[string]ChannelHandler{
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"session": DefaultSessionHandler,
"direct-tcpip": DirectTCPIPHandler,
"[email protected]": DirectStreamLocalHandler,
}

forwardedTCPHandler := &ForwardedTCPHandler{}
forwardedUnixHandler := &ForwardedUnixHandler{}
srv.RequestHandlers = map[string]RequestHandler{
"tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest,
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
"[email protected]": forwardedUnixHandler.HandleSSHRequest,
}

srv.HandleConn(conn)
return nil
}

func newLocalListener() net.Listener {
func newLocalTCPListener() net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
Expand Down Expand Up @@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
}

func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
l := newLocalListener()
l := newLocalTCPListener()
go srv.serveOnce(l)
return newClientSession(t, l.Addr().String(), cfg)
}
Expand Down
20 changes: 20 additions & 0 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"crypto/subtle"
"errors"
"net"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -29,6 +30,9 @@ const (
// DefaultHandler is the default Handler used by Serve.
var DefaultHandler Handler

// ErrReject is returned by some callbacks to reject a request.
var ErrRejected = errors.New("ssh: rejected")

// Option is a functional option handler for Server.
type Option func(*Server) error

Expand Down Expand Up @@ -64,6 +68,22 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti
// ReversePortForwardingCallback is a hook for allowing reverse port forwarding
type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool

// LocalUnixForwardingCallback is a hook for allowing unix forwarding
// ([email protected]). Returning ErrRejected will reject the
// request. The returned net.Conn will be closed by the server when no longer
// needed.
//
// Use SimpleUnixLocalForwardingCallback for a basic implementation.
type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error)

// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
// ([email protected]). Returning ErrRejected will reject the
// request. The returned net.Listener will be closed by the server when no
// longer needed.
//
// Use SimpleUnixReverseForwardingCallback for a basic implementation.
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)

// ServerConfigCallback is a hook for creating custom default server configs
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig

Expand Down
252 changes: 252 additions & 0 deletions streamlocal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
package ssh

import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"sync"
"syscall"

gossh "golang.org/x/crypto/ssh"
)

const (
forwardedUnixChannelType = "[email protected]"
)

// directStreamLocalChannelData data struct as specified in OpenSSH's protocol
// extensions document, Section 2.4.
// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD
type directStreamLocalChannelData struct {
SocketPath string

Reserved1 string
Reserved2 uint32
}

// DirectStreamLocalHandler provides Unix forwarding from client -> server. It
// can be enabled by adding it to the server's ChannelHandlers under
// `[email protected]`.
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
var d directStreamLocalChannelData
err := gossh.Unmarshal(newChan.ExtraData(), &d)
if err != nil {
_ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error())
return
}

if srv.LocalUnixForwardingCallback == nil {
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
return
}
dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath)
if err != nil {
if errors.Is(err, ErrRejected) {
_ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled")
return
}
_ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error()))
return
}

ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go gossh.DiscardRequests(reqs)

bicopy(ctx, ch, dconn)
}

// remoteUnixForwardRequest describes the extra data sent in a
// [email protected] containing the socket path to bind to.
type remoteUnixForwardRequest struct {
SocketPath string
}

// remoteUnixForwardChannelData describes the data sent as the payload in the new
// channel request when a Unix connection is accepted by the listener.
type remoteUnixForwardChannelData struct {
SocketPath string
Reserved uint32
}

// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// `[email protected]` and
// `[email protected]`
//
// Unix socket support on Windows is not widely available, so this handler may
// not work on all Windows installations and is not tested on Windows.
type ForwardedUnixHandler struct {
sync.Mutex
forwards map[string]net.Listener
}

func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
h.Lock()
if h.forwards == nil {
h.forwards = make(map[string]net.Listener)
}
h.Unlock()
conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
if !ok {
// TODO: log cast failure
return false, nil
}

switch req.Type {
case "[email protected]":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}

if srv.ReverseUnixForwardingCallback == nil {
return false, []byte("unix forwarding is disabled")
}

addr := reqPayload.SocketPath
h.Lock()
_, ok := h.forwards[addr]
h.Unlock()
if ok {
// TODO: log failure
return false, nil
}
Comment on lines +119 to +126

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Granted, we haven't introduced a similar option to SSH StreamLocalBindUnlink here, and the default is no for OpenSSH, but in coder/coder we want its yes-behavior. (And my personal opinion is that it's the more useful default.)

If yes, the same socket can be forwarded by multiple sessions, in this case each session should maintain the connections that were opened while it was active. Otherwise the following scenario can't be supported:

  1. Open SSH connection (a) with forwarded socket (/tmp/my.sock)
  2. (a) Open /tmp/my.sock
  3. Open SSH connection (b)
  4. (b) Open /tmp/my.sock
  5. Open SSH connection (c) with forwarded socket (/tmp/my.sock, overwrite)
  6. (c) Open /tmp/my.sock
  7. Close connection (a)
  8. (a) closed (b) Socket closed (c) Socket remains open

In the current implementation, (c) socket can't remain open as the socket wasn't overwritten.

Also: Consider returning true here instead as a default (see motivation in below link).

See: https://github.com/coder/coder/blob/b828412edd913bef6665cf8a0b2ca7ac93334012/agent/agentssh/forward.go#L76-L91


ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
if err != nil {
if errors.Is(err, ErrRejected) {
return false, []byte("unix forwarding is disabled")
}
// TODO: log unix listen failure
return false, nil
}

// The listener needs to successfully start before it can be added to
// the map, so we don't have to worry about checking for an existing
// listener as you can't listen on the same socket twice.
//
// This is also what the TCP version of this code does.
h.Lock()
h.forwards[addr] = ln
h.Unlock()

ctx, cancel := context.WithCancel(ctx)
go func() {
<-ctx.Done()
_ = ln.Close()
}()
go func() {
defer cancel()

for {
c, err := ln.Accept()
if err != nil {
// closed below
break
}
payload := gossh.Marshal(&remoteUnixForwardChannelData{
SocketPath: addr,
})

go func() {
ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload)
if err != nil {
_ = c.Close()
return
}
go gossh.DiscardRequests(reqs)
bicopy(ctx, ch, c)
}()
}

h.Lock()
ln2, ok := h.forwards[addr]
if ok && ln2 == ln {
delete(h.forwards, addr)
}
h.Unlock()
_ = ln.Close()
}()

return true, nil

case "[email protected]":
var reqPayload remoteUnixForwardRequest
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
// TODO: log parse failure
return false, nil
}
h.Lock()
ln, ok := h.forwards[reqPayload.SocketPath]
h.Unlock()
if ok {
_ = ln.Close()
}
return true, nil

default:
return false, nil
}
}

// unlink removes files and unlike os.Remove, directories are kept.
func unlink(path string) error {
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
// for more details.
for {
err := syscall.Unlink(path)
if !errors.Is(err, syscall.EINTR) {
return err
}
}
}

// SimpleUnixLocalForwardingCallback provides a basic implementation for
// LocalUnixForwardingCallback. It will simply dial the requested socket using
// a context-aware dialer.
func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
}

// SimpleUnixReverseForwardingCallback provides a basic implementation for
// ReverseUnixForwardingCallback. The parent directory will be created (with
// os.MkdirAll), and existing files with the same name will be removed.
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
// Create socket parent dir if not exists.
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0700)
if err != nil {
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
}

// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(socketPath)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
}

ln, err := net.Listen("unix", socketPath)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use context aware dialer in the other callback, should we use (&net.ListenConfig{}).Listen(ctx, "unix", addr) here as well?

if err != nil {
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
}

return ln, err
}
Loading