Skip to content

Commit 20a4547

Browse files
notnoopciprogrium
authored andcommitted
Support for local port forwarding (gliderlabs#38)
* Support local port forwarding * refactor testSession to return ssh client as well * Tests for local port forwarding
1 parent 1051a0d commit 20a4547

File tree

7 files changed

+166
-20
lines changed

7 files changed

+166
-20
lines changed

context_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ func TestSetPermissions(t *testing.T) {
77
permsExt := map[string]string{
88
"foo": "bar",
99
}
10-
session, cleanup := newTestSessionWithOptions(t, &Server{
10+
session, _, cleanup := newTestSessionWithOptions(t, &Server{
1111
Handler: func(s Session) {
1212
if _, ok := s.Permissions().Extensions["foo"]; !ok {
1313
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
@@ -29,7 +29,7 @@ func TestSetValue(t *testing.T) {
2929
"foo": "bar",
3030
}
3131
key := "testValue"
32-
session, cleanup := newTestSessionWithOptions(t, &Server{
32+
session, _, cleanup := newTestSessionWithOptions(t, &Server{
3333
Handler: func(s Session) {
3434
v := s.Context().Value(key).(map[string]string)
3535
if v["foo"] != value["foo"] {

options_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
gossh "golang.org/x/crypto/ssh"
88
)
99

10-
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
10+
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) {
1111
for _, option := range options {
1212
if err := srv.SetOption(option); err != nil {
1313
t.Fatal(err)
@@ -20,7 +20,7 @@ func TestPasswordAuth(t *testing.T) {
2020
t.Parallel()
2121
testUser := "testuser"
2222
testPass := "testpass"
23-
session, cleanup := newTestSessionWithOptions(t, &Server{
23+
session, _, cleanup := newTestSessionWithOptions(t, &Server{
2424
Handler: func(s Session) {
2525
// noop
2626
},

server.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ type Server struct {
1717
HostSigners []Signer // private keys for the host key, must have at least one
1818
Version string // server version to be sent before the initial handshake
1919

20-
PasswordHandler PasswordHandler // password authentication handler
21-
PublicKeyHandler PublicKeyHandler // public key authentication handler
22-
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
20+
PasswordHandler PasswordHandler // password authentication handler
21+
PublicKeyHandler PublicKeyHandler // public key authentication handler
22+
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
23+
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
2324

2425
channelHandlers map[string]channelHandler
2526
}
@@ -40,7 +41,8 @@ func (srv *Server) ensureHostSigner() error {
4041

4142
func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
4243
srv.channelHandlers = map[string]channelHandler{
43-
"session": sessionHandler,
44+
"session": sessionHandler,
45+
"direct-tcpip": directTcpipHandler,
4446
}
4547
config := &gossh.ServerConfig{}
4648
for _, signer := range srv.HostSigners {

session_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func newLocalListener() net.Listener {
3232
return l
3333
}
3434

35-
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, func()) {
35+
func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
3636
if config == nil {
3737
config = &gossh.ClientConfig{
3838
User: "testuser",
@@ -52,13 +52,13 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
5252
if err != nil {
5353
t.Fatal(err)
5454
}
55-
return session, func() {
55+
return session, client, func() {
5656
session.Close()
5757
client.Close()
5858
}
5959
}
6060

61-
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, func()) {
61+
func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
6262
l := newLocalListener()
6363
go srv.serveOnce(l)
6464
return newClientSession(t, l.Addr().String(), cfg)
@@ -67,7 +67,7 @@ func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.
6767
func TestStdout(t *testing.T) {
6868
t.Parallel()
6969
testBytes := []byte("Hello world\n")
70-
session, cleanup := newTestSession(t, &Server{
70+
session, _, cleanup := newTestSession(t, &Server{
7171
Handler: func(s Session) {
7272
s.Write(testBytes)
7373
},
@@ -86,7 +86,7 @@ func TestStdout(t *testing.T) {
8686
func TestStderr(t *testing.T) {
8787
t.Parallel()
8888
testBytes := []byte("Hello world\n")
89-
session, cleanup := newTestSession(t, &Server{
89+
session, _, cleanup := newTestSession(t, &Server{
9090
Handler: func(s Session) {
9191
s.Stderr().Write(testBytes)
9292
},
@@ -105,7 +105,7 @@ func TestStderr(t *testing.T) {
105105
func TestStdin(t *testing.T) {
106106
t.Parallel()
107107
testBytes := []byte("Hello world\n")
108-
session, cleanup := newTestSession(t, &Server{
108+
session, _, cleanup := newTestSession(t, &Server{
109109
Handler: func(s Session) {
110110
io.Copy(s, s) // stdin back into stdout
111111
},
@@ -125,7 +125,7 @@ func TestStdin(t *testing.T) {
125125
func TestUser(t *testing.T) {
126126
t.Parallel()
127127
testUser := []byte("progrium")
128-
session, cleanup := newTestSession(t, &Server{
128+
session, _, cleanup := newTestSession(t, &Server{
129129
Handler: func(s Session) {
130130
io.WriteString(s, s.User())
131131
},
@@ -145,7 +145,7 @@ func TestUser(t *testing.T) {
145145

146146
func TestDefaultExitStatusZero(t *testing.T) {
147147
t.Parallel()
148-
session, cleanup := newTestSession(t, &Server{
148+
session, _, cleanup := newTestSession(t, &Server{
149149
Handler: func(s Session) {
150150
// noop
151151
},
@@ -159,7 +159,7 @@ func TestDefaultExitStatusZero(t *testing.T) {
159159

160160
func TestExplicitExitStatusZero(t *testing.T) {
161161
t.Parallel()
162-
session, cleanup := newTestSession(t, &Server{
162+
session, _, cleanup := newTestSession(t, &Server{
163163
Handler: func(s Session) {
164164
s.Exit(0)
165165
},
@@ -173,7 +173,7 @@ func TestExplicitExitStatusZero(t *testing.T) {
173173

174174
func TestExitStatusNonZero(t *testing.T) {
175175
t.Parallel()
176-
session, cleanup := newTestSession(t, &Server{
176+
session, _, cleanup := newTestSession(t, &Server{
177177
Handler: func(s Session) {
178178
s.Exit(1)
179179
},
@@ -195,7 +195,7 @@ func TestPty(t *testing.T) {
195195
winWidth := 40
196196
winHeight := 80
197197
done := make(chan bool)
198-
session, cleanup := newTestSession(t, &Server{
198+
session, _, cleanup := newTestSession(t, &Server{
199199
Handler: func(s Session) {
200200
ptyReq, _, isPty := s.Pty()
201201
if !isPty {
@@ -230,7 +230,7 @@ func TestPtyResize(t *testing.T) {
230230
winch2 := Window{20, 40}
231231
winches := make(chan Window)
232232
done := make(chan bool)
233-
session, cleanup := newTestSession(t, &Server{
233+
session, _, cleanup := newTestSession(t, &Server{
234234
Handler: func(s Session) {
235235
ptyReq, winCh, isPty := s.Pty()
236236
if !isPty {

ssh.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ type PasswordHandler func(ctx Context, password string) bool
4242
// PtyCallback is a hook for allowing PTY sessions.
4343
type PtyCallback func(ctx Context, pty Pty) bool
4444

45+
// LocalPortForwardingCallback is a hook for allowing port forwarding
46+
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
47+
4548
// Window represents the size of a PTY window.
4649
type Window struct {
4750
Width int

tcpip.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package ssh
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net"
7+
8+
gossh "golang.org/x/crypto/ssh"
9+
)
10+
11+
// direct-tcpip data struct as specified in RFC4254, Section 7.2
12+
type forwardData struct {
13+
DestinationHost string
14+
DestinationPort uint32
15+
16+
OriginatorHost string
17+
OriginatorPort uint32
18+
}
19+
20+
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
21+
d := forwardData{}
22+
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
23+
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
24+
return
25+
}
26+
27+
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestinationHost, d.DestinationPort) {
28+
newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
29+
return
30+
}
31+
32+
dest := fmt.Sprintf("%s:%d", d.DestinationHost, d.DestinationPort)
33+
34+
var dialer net.Dialer
35+
dconn, err := dialer.DialContext(ctx, "tcp", dest)
36+
if err != nil {
37+
newChan.Reject(gossh.ConnectionFailed, err.Error())
38+
return
39+
}
40+
41+
ch, reqs, err := newChan.Accept()
42+
if err != nil {
43+
dconn.Close()
44+
return
45+
}
46+
go gossh.DiscardRequests(reqs)
47+
48+
go func() {
49+
defer ch.Close()
50+
defer dconn.Close()
51+
io.Copy(ch, dconn)
52+
}()
53+
go func() {
54+
defer ch.Close()
55+
defer dconn.Close()
56+
io.Copy(dconn, ch)
57+
}()
58+
}

tcpip_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package ssh
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io/ioutil"
7+
"net"
8+
"strings"
9+
"testing"
10+
11+
gossh "golang.org/x/crypto/ssh"
12+
)
13+
14+
var sampleServerResponse = []byte("Hello world")
15+
16+
func sampleSocketServer() net.Listener {
17+
l := newLocalListener()
18+
19+
go func() {
20+
conn, err := l.Accept()
21+
if err != nil {
22+
return
23+
}
24+
conn.Write(sampleServerResponse)
25+
conn.Close()
26+
}()
27+
28+
return l
29+
}
30+
31+
func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
32+
l := sampleSocketServer()
33+
34+
_, client, cleanup := newTestSession(t, &Server{
35+
Handler: func(s Session) {},
36+
LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
37+
addr := fmt.Sprintf("%s:%d", destinationHost, destinationPort)
38+
if addr != l.Addr().String() {
39+
panic("unexpected destinationHost: " + addr)
40+
}
41+
return forwardingEnabled
42+
},
43+
}, nil)
44+
45+
return l, client, func() {
46+
cleanup()
47+
l.Close()
48+
}
49+
}
50+
51+
func TestLocalPortForwardingWorks(t *testing.T) {
52+
t.Parallel()
53+
54+
l, client, cleanup := newTestSessionWithForwarding(t, true)
55+
defer cleanup()
56+
57+
conn, err := client.Dial("tcp", l.Addr().String())
58+
if err != nil {
59+
t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
60+
}
61+
result, err := ioutil.ReadAll(conn)
62+
if err != nil {
63+
t.Fatal(err)
64+
}
65+
if !bytes.Equal(result, sampleServerResponse) {
66+
t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
67+
}
68+
}
69+
70+
func TestLocalPortForwardingRespectsCallback(t *testing.T) {
71+
t.Parallel()
72+
73+
l, client, cleanup := newTestSessionWithForwarding(t, false)
74+
defer cleanup()
75+
76+
_, err := client.Dial("tcp", l.Addr().String())
77+
if err == nil {
78+
t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
79+
}
80+
if !strings.Contains(err.Error(), "port forwarding is disabled") {
81+
t.Fatalf("Expected permission error but got %#v", err)
82+
}
83+
}

0 commit comments

Comments
 (0)