Skip to content

Commit 9f2612e

Browse files
committed
Make authentication an interface for easy customization
1 parent dec4081 commit 9f2612e

File tree

4 files changed

+65
-32
lines changed

4 files changed

+65
-32
lines changed

server/auth.go

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err
2727
return c.handleAuthSwitchResponse()
2828
}
2929

30-
switch authPluginName {
31-
case mysql.AUTH_NATIVE_PASSWORD:
32-
return c.compareNativePasswordAuthData(clientAuthData, c.credential)
33-
34-
case mysql.AUTH_CACHING_SHA2_PASSWORD:
35-
if !c.cachingSha2FullAuth {
36-
if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil {
37-
return err
38-
}
39-
if c.cachingSha2FullAuth {
40-
return c.handleAuthSwitchResponse()
41-
}
42-
return nil
43-
}
44-
// AuthMoreData packet already sent, do full auth
45-
return c.handleCachingSha2PasswordFullAuth(clientAuthData)
46-
47-
case mysql.AUTH_SHA256_PASSWORD:
48-
cont, err := c.handlePublicKeyRetrieval(clientAuthData)
49-
if err != nil {
50-
return err
51-
}
52-
if !cont {
53-
return nil
54-
}
55-
return c.compareSha256PasswordAuthData(clientAuthData, c.credential)
56-
57-
default:
58-
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
59-
}
30+
return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData)
6031
}
6132

6233
func (c *Conn) acquirePassword() error {

server/authentication_provider.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package server
2+
3+
import (
4+
"github.com/go-mysql-org/go-mysql/mysql"
5+
"github.com/pingcap/errors"
6+
)
7+
8+
type AuthenticationProvider interface {
9+
Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error
10+
Validate(authPluginName string) bool
11+
}
12+
13+
type DefaultAuthenticationProvider struct{}
14+
15+
func (d *DefaultAuthenticationProvider) Authenticate(c *Conn, authPluginName string, clientAuthData []byte) error {
16+
switch authPluginName {
17+
case mysql.AUTH_NATIVE_PASSWORD:
18+
return c.compareNativePasswordAuthData(clientAuthData, c.credential)
19+
20+
case mysql.AUTH_CACHING_SHA2_PASSWORD:
21+
if !c.cachingSha2FullAuth {
22+
if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil {
23+
return err
24+
}
25+
if c.cachingSha2FullAuth {
26+
return c.handleAuthSwitchResponse()
27+
}
28+
return nil
29+
}
30+
// AuthMoreData packet already sent, do full auth
31+
return c.handleCachingSha2PasswordFullAuth(clientAuthData)
32+
33+
case mysql.AUTH_SHA256_PASSWORD:
34+
cont, err := c.handlePublicKeyRetrieval(clientAuthData)
35+
if err != nil {
36+
return err
37+
}
38+
if !cont {
39+
return nil
40+
}
41+
return c.compareSha256PasswordAuthData(clientAuthData, c.credential)
42+
43+
default:
44+
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
45+
}
46+
}
47+
48+
func (d *DefaultAuthenticationProvider) Validate(authPluginName string) bool {
49+
return isAuthMethodSupported(authPluginName)
50+
}

server/conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ func (s *Server) NewConn(conn net.Conn, user string, password string, h Handler)
6868
return s.NewCustomizedConn(conn, p, h)
6969
}
7070

71-
// NewCustomizedConn: create connection with customized server settings
7271
func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handler) (*Conn, error) {
7372
var packetConn *packet.Conn
7473
if s.tlsConfig != nil {
@@ -96,6 +95,11 @@ func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handle
9695
return c, nil
9796
}
9897

98+
// NewCustomizedConn: create connection with customized server settings
99+
// func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handler) (*Conn, error) {
100+
// return s.NewCustomizedConnWithAuth(conn, p, h, nil)
101+
// }
102+
99103
func (c *Conn) handshake() error {
100104
if err := c.writeInitialHandshake(); err != nil {
101105
return err

server/server_conf.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Server struct {
3232
pubKey []byte
3333
tlsConfig *tls.Config
3434
cacheShaPassword *sync.Map // 'user@host' -> SHA256(SHA256(PASSWORD))
35+
authProvider AuthenticationProvider
3536
}
3637

3738
// NewDefaultServer: New mysql server with default settings.
@@ -56,6 +57,7 @@ func NewDefaultServer() *Server {
5657
pubKey: getPublicKeyFromCert(certPem),
5758
tlsConfig: tlsConf,
5859
cacheShaPassword: new(sync.Map),
60+
authProvider: &DefaultAuthenticationProvider{},
5961
}
6062
}
6163

@@ -69,7 +71,12 @@ func NewDefaultServer() *Server {
6971
// And for TLS support, you can specify self-signed or CA-signed certificates and decide whether the client needs to provide
7072
// a signed or unsigned certificate to provide different level of security.
7173
func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config) *Server {
72-
if !isAuthMethodSupported(defaultAuthMethod) {
74+
authProvider := &DefaultAuthenticationProvider{}
75+
return NewServerWithAuth(serverVersion, collationId, defaultAuthMethod, pubKey, tlsConfig, authProvider)
76+
}
77+
78+
func NewServerWithAuth(serverVersion string, collationId uint8, defaultAuthMethod string, pubKey []byte, tlsConfig *tls.Config, authProvider AuthenticationProvider) *Server {
79+
if authProvider == nil || !authProvider.Validate(defaultAuthMethod) {
7380
panic(fmt.Sprintf("server authentication method '%s' is not supported", defaultAuthMethod))
7481
}
7582

@@ -91,6 +98,7 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string
9198
pubKey: pubKey,
9299
tlsConfig: tlsConfig,
93100
cacheShaPassword: new(sync.Map),
101+
authProvider: authProvider,
94102
}
95103
}
96104

0 commit comments

Comments
 (0)