Skip to content

Commit 9b56478

Browse files
authored
contexts (gliderlabs#29)
* context: working mostly tested context implementation and refactoring to go with it * _example/ssh-publickey: updating new context based callbacks * godocs related to public api changes for contexts * context: converting []bytes to strings before putting into context Signed-off-by: Jeff Lindsay <[email protected]>
1 parent 791cd4b commit 9b56478

File tree

11 files changed

+343
-73
lines changed

11 files changed

+343
-73
lines changed

_example/ssh-publickey/public_key.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func main() {
1616
s.Write(authorizedKey)
1717
})
1818

19-
publicKeyOption := ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
19+
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
2020
return true // allow all keys, or use ssh.KeysEqual() to compare against known keys
2121
})
2222

context.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package ssh
2+
3+
import (
4+
"context"
5+
"net"
6+
7+
gossh "golang.org/x/crypto/ssh"
8+
)
9+
10+
// contextKey is a value for use with context.WithValue. It's used as
11+
// a pointer so it fits in an interface{} without allocation.
12+
type contextKey struct {
13+
name string
14+
}
15+
16+
var (
17+
// ContextKeyUser is a context key for use with Contexts in this package.
18+
// The associated value will be of type string.
19+
ContextKeyUser = &contextKey{"user"}
20+
21+
// ContextKeySessionID is a context key for use with Contexts in this package.
22+
// The associated value will be of type string.
23+
ContextKeySessionID = &contextKey{"session-id"}
24+
25+
// ContextKeyPermissions is a context key for use with Contexts in this package.
26+
// The associated value will be of type *Permissions.
27+
ContextKeyPermissions = &contextKey{"permissions"}
28+
29+
// ContextKeyClientVersion is a context key for use with Contexts in this package.
30+
// The associated value will be of type string.
31+
ContextKeyClientVersion = &contextKey{"client-version"}
32+
33+
// ContextKeyServerVersion is a context key for use with Contexts in this package.
34+
// The associated value will be of type string.
35+
ContextKeyServerVersion = &contextKey{"server-version"}
36+
37+
// ContextKeyLocalAddr is a context key for use with Contexts in this package.
38+
// The associated value will be of type net.Addr.
39+
ContextKeyLocalAddr = &contextKey{"local-addr"}
40+
41+
// ContextKeyRemoteAddr is a context key for use with Contexts in this package.
42+
// The associated value will be of type net.Addr.
43+
ContextKeyRemoteAddr = &contextKey{"remote-addr"}
44+
45+
// ContextKeyServer is a context key for use with Contexts in this package.
46+
// The associated value will be of type *Server.
47+
ContextKeyServer = &contextKey{"ssh-server"}
48+
49+
// ContextKeyPublicKey is a context key for use with Contexts in this package.
50+
// The associated value will be of type PublicKey.
51+
ContextKeyPublicKey = &contextKey{"public-key"}
52+
)
53+
54+
// Context is a package specific context interface. It exposes connection
55+
// metadata and allows new values to be easily written to it. It's used in
56+
// authentication handlers and callbacks, and its underlying context.Context is
57+
// exposed on Session in the session Handler.
58+
type Context interface {
59+
context.Context
60+
61+
// User returns the username used when establishing the SSH connection.
62+
User() string
63+
64+
// SessionID returns the session hash.
65+
SessionID() string
66+
67+
// ClientVersion returns the version reported by the client.
68+
ClientVersion() string
69+
70+
// ServerVersion returns the version reported by the server.
71+
ServerVersion() string
72+
73+
// RemoteAddr returns the remote address for this connection.
74+
RemoteAddr() net.Addr
75+
76+
// LocalAddr returns the local address for this connection.
77+
LocalAddr() net.Addr
78+
79+
// Permissions returns the Permissions object used for this connection.
80+
Permissions() *Permissions
81+
82+
// SetValue allows you to easily write new values into the underlying context.
83+
SetValue(key, value interface{})
84+
}
85+
86+
type sshContext struct {
87+
context.Context
88+
}
89+
90+
func newContext(srv *Server) *sshContext {
91+
ctx := &sshContext{context.Background()}
92+
ctx.SetValue(ContextKeyServer, srv)
93+
perms := &Permissions{&gossh.Permissions{}}
94+
ctx.SetValue(ContextKeyPermissions, perms)
95+
return ctx
96+
}
97+
98+
// this is separate from newContext because we will get ConnMetadata
99+
// at different points so it needs to be applied separately
100+
func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
101+
if ctx.Value(ContextKeySessionID) != nil {
102+
return
103+
}
104+
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
105+
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
106+
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
107+
ctx.SetValue(ContextKeyUser, conn.User())
108+
ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
109+
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
110+
}
111+
112+
func (ctx *sshContext) SetValue(key, value interface{}) {
113+
ctx.Context = context.WithValue(ctx.Context, key, value)
114+
}
115+
116+
func (ctx *sshContext) User() string {
117+
return ctx.Value(ContextKeyUser).(string)
118+
}
119+
120+
func (ctx *sshContext) SessionID() string {
121+
return ctx.Value(ContextKeySessionID).(string)
122+
}
123+
124+
func (ctx *sshContext) ClientVersion() string {
125+
return ctx.Value(ContextKeyClientVersion).(string)
126+
}
127+
128+
func (ctx *sshContext) ServerVersion() string {
129+
return ctx.Value(ContextKeyServerVersion).(string)
130+
}
131+
132+
func (ctx *sshContext) RemoteAddr() net.Addr {
133+
return ctx.Value(ContextKeyRemoteAddr).(net.Addr)
134+
}
135+
136+
func (ctx *sshContext) LocalAddr() net.Addr {
137+
return ctx.Value(ContextKeyLocalAddr).(net.Addr)
138+
}
139+
140+
func (ctx *sshContext) Permissions() *Permissions {
141+
return ctx.Value(ContextKeyPermissions).(*Permissions)
142+
}

context_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package ssh
2+
3+
import "testing"
4+
5+
func TestSetPermissions(t *testing.T) {
6+
t.Parallel()
7+
permsExt := map[string]string{
8+
"foo": "bar",
9+
}
10+
session, cleanup := newTestSessionWithOptions(t, &Server{
11+
Handler: func(s Session) {
12+
if _, ok := s.Permissions().Extensions["foo"]; !ok {
13+
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
14+
}
15+
},
16+
}, nil, PasswordAuth(func(ctx Context, password string) bool {
17+
ctx.Permissions().Extensions = permsExt
18+
return true
19+
}))
20+
defer cleanup()
21+
if err := session.Run(""); err != nil {
22+
t.Fatal(err)
23+
}
24+
}
25+
26+
func TestSetValue(t *testing.T) {
27+
t.Parallel()
28+
value := map[string]string{
29+
"foo": "bar",
30+
}
31+
key := "testValue"
32+
session, cleanup := newTestSessionWithOptions(t, &Server{
33+
Handler: func(s Session) {
34+
v := s.Context().Value(key).(map[string]string)
35+
if v["foo"] != value["foo"] {
36+
t.Fatalf("got %#v; want %#v", v, value)
37+
}
38+
},
39+
}, nil, PasswordAuth(func(ctx Context, password string) bool {
40+
ctx.SetValue(key, value)
41+
return true
42+
}))
43+
defer cleanup()
44+
if err := session.Run(""); err != nil {
45+
t.Fatal(err)
46+
}
47+
}

example_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func ExampleListenAndServe() {
1515

1616
func ExamplePasswordAuth() {
1717
ssh.ListenAndServe(":2222", nil,
18-
ssh.PasswordAuth(func(user, pass string) bool {
18+
ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
1919
return pass == "secret"
2020
}),
2121
)
@@ -27,7 +27,7 @@ func ExampleNoPty() {
2727

2828
func ExamplePublicKeyAuth() {
2929
ssh.ListenAndServe(":2222", nil,
30-
ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
30+
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
3131
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
3232
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
3333
return ssh.KeysEqual(key, allowed)

options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option {
5656
// denying PTY requests.
5757
func NoPty() Option {
5858
return func(srv *Server) error {
59-
srv.PtyCallback = func(user string, permissions *Permissions) bool {
59+
srv.PtyCallback = func(ctx Context, pty Pty) bool {
6060
return false
6161
}
6262
return nil

options_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package ssh
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
gossh "golang.org/x/crypto/ssh"
8+
)
9+
10+
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
11+
for _, option := range options {
12+
if err := srv.SetOption(option); err != nil {
13+
t.Fatal(err)
14+
}
15+
}
16+
return newTestSession(t, srv, cfg)
17+
}
18+
19+
func TestPasswordAuth(t *testing.T) {
20+
t.Parallel()
21+
testUser := "testuser"
22+
testPass := "testpass"
23+
session, cleanup := newTestSessionWithOptions(t, &Server{
24+
Handler: func(s Session) {
25+
// noop
26+
},
27+
}, &gossh.ClientConfig{
28+
User: testUser,
29+
Auth: []gossh.AuthMethod{
30+
gossh.Password(testPass),
31+
},
32+
}, PasswordAuth(func(ctx Context, password string) bool {
33+
if ctx.User() != testUser {
34+
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
35+
}
36+
if password != testPass {
37+
t.Fatalf("user = %#v; want %#v", password, testPass)
38+
}
39+
return true
40+
}))
41+
defer cleanup()
42+
if err := session.Run(""); err != nil {
43+
t.Fatal(err)
44+
}
45+
}
46+
47+
func TestPasswordAuthBadPass(t *testing.T) {
48+
t.Parallel()
49+
l := newLocalListener()
50+
srv := &Server{Handler: func(s Session) {}}
51+
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
52+
return false
53+
}))
54+
go srv.serveOnce(l)
55+
_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
56+
User: "testuser",
57+
Auth: []gossh.AuthMethod{
58+
gossh.Password("testpass"),
59+
},
60+
})
61+
if err != nil {
62+
if !strings.Contains(err.Error(), "unable to authenticate") {
63+
t.Fatal(err)
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)