diff --git a/doc.go b/doc.go index 5910ce8a..3cc37116 100644 --- a/doc.go +++ b/doc.go @@ -67,6 +67,47 @@ Valid values for sslmode are: the server was signed by a trusted CA and the server host name matches the one in the certificate) +For support ssl key in memory, we extend sslmode. For example: + + import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "log" + + "github.com/lib/pq" + ) + + func main() { + rootCertPool := x509.NewCertPool() + pem, err := ioutil.ReadFile("ca.crt") + if err != nil { + log.Fatal(err) + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + log.Fatal("Failed to append PEM.") + } + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair("client1.crt", "client1.key") + if err != nil { + log.Fatal(err) + } + clientCert = append(clientCert, certs) + err = pq.RegisterTLSConfig("custom", &tls.Config{ + RootCAs: rootCertPool, + Certificates: clientCert, + ServerName: "pq.example.com", + }) + if err != nil { + log.Fatal(err) + } + connStr := "host=pq.example.com port=5432 user=user1 dbname=pqgotest password=pqgotest sslmode=custom" + db, err := sql.Open("postgres", connStr) + if err != nil { + log.Fatal(err) + } + } + See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING for more information about connection string parameters. diff --git a/ssl.go b/ssl.go index b284c572..fcd224cd 100644 --- a/ssl.go +++ b/ssl.go @@ -10,19 +10,80 @@ import ( "path/filepath" "runtime" "strings" + "sync" "syscall" "github.com/lib/pq/internal/pqutil" ) +// Registry for custom tls.Configs +var ( + tlsConfs = make(map[string]*tls.Config) + tlsConfsMu sync.RWMutex +) + +// For example: +// +// import ( +// "crypto/tls" +// "crypto/x509" +// "io/ioutil" +// "log" +// +// "github.com/lib/pq" +// ) +// +// func main() { +// rootCertPool := x509.NewCertPool() +// pem, _ := ioutil.ReadFile("ca.crt") +// rootCertPool.AppendCertsFromPEM(pem) +// +// certs, _ := tls.LoadX509KeyPair("client1.crt", "client1.key") +// +// pq.RegisterTLSConfig("mytls", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: []tls.Certificate{certs}, +// ServerName: "pq.example.com", +// }) +// +// db, _ := sql.Open("postgres", "sslmode=pqgo-mytls") +// } +// +// Use nil for the config to remove. +func RegisterTLSConfig(key string, config *tls.Config) error { + if config == nil { + tlsConfsMu.Lock() + delete(tlsConfs, key) + tlsConfsMu.Unlock() + return nil + } + + tlsConfsMu.Lock() + tlsConfs[key] = config + tlsConfsMu.Unlock() + return nil +} + +func getTLSConfigClone(key string) *tls.Config { + tlsConfsMu.RLock() + if v, ok := tlsConfs[key]; ok { + return v.Clone() + } + tlsConfsMu.RUnlock() + return nil +} + // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. func ssl(o values) (func(net.Conn) (net.Conn, error), error) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o["sslmode"]; mode { + var ( + verifyCaOnly = false + tlsConf = &tls.Config{} + mode = o["sslmode"] + ) + switch { // "require" is the default. - case "", "require": + case mode == "" || mode == "require": // We must skip TLS's own verification since it requires full // verification since Go 1.3. tlsConf.InsecureSkipVerify = true @@ -42,15 +103,20 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { delete(o, "sslrootcert") } } - case "verify-ca": + case mode == "verify-ca": // We must skip TLS's own verification since it requires full // verification since Go 1.3. tlsConf.InsecureSkipVerify = true verifyCaOnly = true - case "verify-full": + case mode == "verify-full": tlsConf.ServerName = o["host"] - case "disable": + case mode == "disable": return nil, nil + case strings.HasPrefix(mode, "pqgo-"): + tlsConf = getTLSConfigClone(mode[5:]) + if tlsConf == nil { + return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode) + } default: return nil, fmt.Errorf( `pq: unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, @@ -67,11 +133,11 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.ServerName = o["host"] } - err := sslClientCertificates(&tlsConf, o) + err := sslClientCertificates(tlsConf, o) if err != nil { return nil, err } - err = sslCertificateAuthority(&tlsConf, o) + err = sslCertificateAuthority(tlsConf, o) if err != nil { return nil, err } @@ -84,7 +150,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { - client := tls.Client(conn, &tlsConf) + client := tls.Client(conn, tlsConf) if verifyCaOnly { err := client.Handshake() if err != nil {