diff --git a/client/errs/errno.go b/client/errs/errno.go index 95c6bffdfa4..752df99840f 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -61,6 +61,7 @@ var ( // grpcutil errors var ( ErrSecurityConfig = errors.Normalize("security config error: %s", errors.RFCCodeText("PD:grpcutil:ErrSecurityConfig")) + ErrTLSConfig = errors.Normalize("TLS config error", errors.RFCCodeText("PD:grpcutil:ErrTLSConfig")) ) // The third-party project error. @@ -75,11 +76,6 @@ var ( ErrCloseGRPCConn = errors.Normalize("close gRPC connection failed", errors.RFCCodeText("PD:grpc:ErrCloseGRPCConn")) ) -// etcd errors -var ( - ErrEtcdTLSConfig = errors.Normalize("etcd TLS config error", errors.RFCCodeText("PD:etcd:ErrEtcdTLSConfig")) -) - // crypto var ( ErrCryptoX509KeyPair = errors.Normalize("x509 keypair error", errors.RFCCodeText("PD:crypto:ErrCryptoX509KeyPair")) diff --git a/client/tlsutil/tlsconfig.go b/client/tlsutil/tlsconfig.go index 88d797d3b3a..c7f1edf305f 100644 --- a/client/tlsutil/tlsconfig.go +++ b/client/tlsutil/tlsconfig.go @@ -206,7 +206,7 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { tlsConfig, err := tlsInfo.clientConfig() if err != nil { - return nil, errs.ErrEtcdTLSConfig.Wrap(err).GenWithStackByCause() + return nil, errs.ErrTLSConfig.Wrap(err).GenWithStackByCause() } return tlsConfig, nil } diff --git a/errors.toml b/errors.toml index 47de8a80211..6ecc8ff8212 100644 --- a/errors.toml +++ b/errors.toml @@ -376,11 +376,6 @@ error = ''' etcd move leader error ''' -["PD:etcd:ErrEtcdTLSConfig"] -error = ''' -etcd TLS config error -''' - ["PD:etcd:ErrEtcdTxnConflict"] error = ''' etcd transaction failed, conflicted and rolled back @@ -456,6 +451,11 @@ error = ''' security config error: %s ''' +["PD:grpcutil:ErrTLSConfig"] +error = ''' +TLS config error +''' + ["PD:hex:ErrHexDecodingString"] error = ''' decode string %s error diff --git a/pkg/dashboard/adapter/config.go b/pkg/dashboard/adapter/config.go index cf79690e73b..15fce8cf8bd 100644 --- a/pkg/dashboard/adapter/config.go +++ b/pkg/dashboard/adapter/config.go @@ -36,7 +36,7 @@ func GenDashboardConfig(srv *server.Server) (*config.Config, error) { dashboardCfg.EnableTelemetry = cfg.Dashboard.EnableTelemetry dashboardCfg.EnableExperimental = cfg.Dashboard.EnableExperimental dashboardCfg.DisableCustomPromAddr = cfg.Dashboard.DisableCustomPromAddr - if dashboardCfg.ClusterTLSConfig, err = cfg.Security.ToTLSConfig(); err != nil { + if dashboardCfg.ClusterTLSConfig, err = cfg.Security.ToClientTLSConfig(); err != nil { return nil, err } if dashboardCfg.ClusterTLSInfo, err = cfg.Security.ToTLSInfo(); err != nil { diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 668059ec8aa..3e460e9d637 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -206,6 +206,7 @@ var ( // grpcutil errors var ( ErrSecurityConfig = errors.Normalize("security config error: %s", errors.RFCCodeText("PD:grpcutil:ErrSecurityConfig")) + ErrTLSConfig = errors.Normalize("TLS config error", errors.RFCCodeText("PD:grpcutil:ErrTLSConfig")) ) // server errors @@ -268,7 +269,6 @@ var ( ErrEtcdKVGetResponse = errors.Normalize("etcd invalid get value response %v, must only one", errors.RFCCodeText("PD:etcd:ErrEtcdKVGetResponse")) ErrEtcdGetCluster = errors.Normalize("etcd get cluster from remote peer failed", errors.RFCCodeText("PD:etcd:ErrEtcdGetCluster")) ErrEtcdMoveLeader = errors.Normalize("etcd move leader error", errors.RFCCodeText("PD:etcd:ErrEtcdMoveLeader")) - ErrEtcdTLSConfig = errors.Normalize("etcd TLS config error", errors.RFCCodeText("PD:etcd:ErrEtcdTLSConfig")) ErrEtcdWatcherCancel = errors.Normalize("watcher canceled", errors.RFCCodeText("PD:etcd:ErrEtcdWatcherCancel")) ErrCloseEtcdClient = errors.Normalize("close etcd client failed", errors.RFCCodeText("PD:etcd:ErrCloseEtcdClient")) ErrEtcdMemberList = errors.Normalize("etcd member list failed", errors.RFCCodeText("PD:etcd:ErrEtcdMemberList")) diff --git a/pkg/mcs/server/server.go b/pkg/mcs/server/server.go index fef05a85012..8cfdd4395fb 100644 --- a/pkg/mcs/server/server.go +++ b/pkg/mcs/server/server.go @@ -66,7 +66,7 @@ func (bs *BaseServer) Context() context.Context { func (bs *BaseServer) GetDelegateClient(ctx context.Context, tlsCfg *grpcutil.TLSConfig, forwardedHost string) (*grpc.ClientConn, error) { client, ok := bs.clientConns.Load(forwardedHost) if !ok { - tlsConfig, err := tlsCfg.ToTLSConfig() + tlsConfig, err := tlsCfg.ToClientTLSConfig() if err != nil { return nil, err } @@ -146,7 +146,7 @@ func (bs *BaseServer) InitListener(tlsCfg *grpcutil.TLSConfig, listenAddr string if err != nil { return err } - tlsConfig, err := tlsCfg.ToTLSConfig() + tlsConfig, err := tlsCfg.ToServerTLSConfig() if err != nil { return err } diff --git a/pkg/mcs/utils/util.go b/pkg/mcs/utils/util.go index c97d523edd2..9139d3c3e4e 100644 --- a/pkg/mcs/utils/util.go +++ b/pkg/mcs/utils/util.go @@ -102,7 +102,7 @@ type server interface { // InitClient initializes the etcd and http clients. func InitClient(s server) error { - tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + tlsConfig, err := s.GetTLSConfig().ToClientTLSConfig() if err != nil { return err } diff --git a/pkg/tso/allocator_manager.go b/pkg/tso/allocator_manager.go index a02d4884e17..26c70298806 100644 --- a/pkg/tso/allocator_manager.go +++ b/pkg/tso/allocator_manager.go @@ -1223,7 +1223,7 @@ func (am *AllocatorManager) getOrCreateGRPCConn(ctx context.Context, addr string if ok { return conn, nil } - tlsCfg, err := am.securityConfig.ToTLSConfig() + tlsCfg, err := am.securityConfig.ToClientTLSConfig() if err != nil { return nil, err } diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index 4178617cd74..0e6b8a453de 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -74,8 +74,8 @@ func (s TLSConfig) ToTLSInfo() (*transport.TLSInfo, error) { }, nil } -// ToTLSConfig generates tls config. -func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { +// ToClientTLSConfig generates tls config. +func (s TLSConfig) ToClientTLSConfig() (*tls.Config, error) { if len(s.SSLCABytes) != 0 || len(s.SSLCertBytes) != 0 || len(s.SSLKEYBytes) != 0 { cert, err := tls.X509KeyPair(s.SSLCertBytes, s.SSLKEYBytes) if err != nil { @@ -100,16 +100,39 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { return nil, nil } if err != nil { - return nil, errs.ErrEtcdTLSConfig.Wrap(err).GenWithStackByCause() + return nil, errs.ErrTLSConfig.Wrap(err).GenWithStackByCause() } tlsConfig, err := tlsInfo.ClientConfig() if err != nil { - return nil, errs.ErrEtcdTLSConfig.Wrap(err).GenWithStackByCause() + return nil, errs.ErrTLSConfig.Wrap(err).GenWithStackByCause() } return tlsConfig, nil } +// ToServerTLSConfig generates tls config. +func (s TLSConfig) ToServerTLSConfig() (*tls.Config, error) { + if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { + return nil, nil + } + + tlsInfo := transport.TLSInfo{ + CertFile: s.CertPath, + KeyFile: s.KeyPath, + TrustedCAFile: s.CAPath, + AllowedCNs: s.CertAllowedCNs, + } + + tlsConfig, err := tlsInfo.ServerConfig() + if err != nil { + return nil, errs.ErrTLSConfig.Wrap(err).GenWithStackByCause() + } + tlsConfig.NextProtos = []string{"http/1.1", "h2"} + tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven + + return tlsConfig, nil +} + // GetClientConn returns a gRPC client connection. // creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be @@ -189,7 +212,7 @@ func IsFollowerHandleEnabled(ctx context.Context) bool { } func establish(ctx context.Context, addr string, tlsConfig *TLSConfig, do ...grpc.DialOption) (*grpc.ClientConn, error) { - tlsCfg, err := tlsConfig.ToTLSConfig() + tlsCfg, err := tlsConfig.ToClientTLSConfig() if err != nil { return nil, err } diff --git a/pkg/utils/grpcutil/grpcutil_test.go b/pkg/utils/grpcutil/grpcutil_test.go index fbcfe59f02c..2e83403195b 100644 --- a/pkg/utils/grpcutil/grpcutil_test.go +++ b/pkg/utils/grpcutil/grpcutil_test.go @@ -2,6 +2,8 @@ package grpcutil import ( "context" + "crypto/tls" + "crypto/x509" "os" "os/exec" "path/filepath" @@ -29,7 +31,8 @@ func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (c return } -func TestToTLSConfig(t *testing.T) { +func TestToClientTLSConfig(t *testing.T) { + re := require.New(t) if err := exec.Command(certScript, "generate", certPath).Run(); err != nil { t.Fatal(err) } @@ -39,14 +42,13 @@ func TestToTLSConfig(t *testing.T) { } }() - re := require.New(t) tlsConfig := TLSConfig{ KeyPath: filepath.Join(certPath, "pd-server-key.pem"), CertPath: filepath.Join(certPath, "pd-server.pem"), CAPath: filepath.Join(certPath, "ca.pem"), } // test without bytes - _, err := tlsConfig.ToTLSConfig() + _, err := tlsConfig.ToClientTLSConfig() re.NoError(err) // test with bytes @@ -54,21 +56,124 @@ func TestToTLSConfig(t *testing.T) { tlsConfig.SSLCABytes = caData tlsConfig.SSLCertBytes = certData tlsConfig.SSLKEYBytes = keyData - _, err = tlsConfig.ToTLSConfig() + _, err = tlsConfig.ToClientTLSConfig() re.NoError(err) // test wrong cert bytes tlsConfig.SSLCertBytes = []byte("invalid cert") - _, err = tlsConfig.ToTLSConfig() + _, err = tlsConfig.ToClientTLSConfig() re.True(errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) // test wrong ca bytes tlsConfig.SSLCertBytes = certData tlsConfig.SSLCABytes = []byte("invalid ca") - _, err = tlsConfig.ToTLSConfig() + _, err = tlsConfig.ToClientTLSConfig() re.True(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } +func TestToServerTLSConfig(t *testing.T) { + re := require.New(t) + + if err := exec.Command(certScript, "generate", certPath).Run(); err != nil { + t.Fatal(err) + } + defer func() { + if err := exec.Command(certScript, "cleanup", certPath).Run(); err != nil { + t.Fatal(err) + } + }() + + testCases := []struct { + name string + tlsConfig TLSConfig + wantErr bool + checkConfig bool + allowedCNs []string + validateProto bool + }{ + { + name: "valid certificate configuration", + tlsConfig: TLSConfig{ + KeyPath: filepath.Join(certPath, "pd-server-key.pem"), + CertPath: filepath.Join(certPath, "pd-server.pem"), + CAPath: filepath.Join(certPath, "ca.pem"), + }, + checkConfig: true, + validateProto: true, + }, + { + name: "with allowed CNs", + tlsConfig: TLSConfig{ + KeyPath: filepath.Join(certPath, "pd-server-key.pem"), + CertPath: filepath.Join(certPath, "pd-server.pem"), + CAPath: filepath.Join(certPath, "ca.pem"), + CertAllowedCNs: []string{"pd-server"}, + }, + checkConfig: true, + allowedCNs: []string{"pd-server"}, + }, + { + name: "empty cert and key paths", + tlsConfig: TLSConfig{ + CAPath: filepath.Join(certPath, "ca.pem"), + }, + wantErr: false, // Should return nil config, not error + }, + { + name: "invalid cert path", + tlsConfig: TLSConfig{ + CAPath: filepath.Join(certPath, "ca.pem"), + CertPath: "non-existent.pem", + KeyPath: filepath.Join(certPath, "pd-server-key.pem"), + }, + wantErr: true, + }, + { + name: "invalid key path", + tlsConfig: TLSConfig{ + CAPath: filepath.Join(certPath, "ca.pem"), + CertPath: filepath.Join(certPath, "pd-server-key.pem"), + KeyPath: "non-existent.pem", + }, + wantErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(_ *testing.T) { + tlsConfig, err := testCase.tlsConfig.ToServerTLSConfig() + if testCase.wantErr { + re.Error(err) + return + } + re.NoError(err) + + if !testCase.checkConfig { + if testCase.tlsConfig.CertPath == "" && testCase.tlsConfig.KeyPath == "" { + re.Nil(tlsConfig) + } + return + } + + re.NotNil(tlsConfig) + re.Equal(tls.VerifyClientCertIfGiven, tlsConfig.ClientAuth) + + if testCase.validateProto { + re.Contains(tlsConfig.NextProtos, "http/1.1") + re.Contains(tlsConfig.NextProtos, "h2") + } + + // Validate allowed CNs + if len(testCase.allowedCNs) > 0 { + cert, err := tls.LoadX509KeyPair(testCase.tlsConfig.CertPath, testCase.tlsConfig.KeyPath) + re.NoError(err) + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + re.NoError(err) + re.Contains(testCase.allowedCNs, x509Cert.Subject.CommonName) + } + }) + } +} func BenchmarkGetForwardedHost(b *testing.B) { // Without forwarded host key md := metadata.Pairs("test", "example.com") diff --git a/server/config/config_test.go b/server/config/config_test.go index f326d42cdde..93a094f6319 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -42,7 +42,7 @@ func TestSecurity(t *testing.T) { func TestTLS(t *testing.T) { re := require.New(t) cfg := NewConfig() - tls, err := cfg.Security.ToTLSConfig() + tls, err := cfg.Security.ToClientTLSConfig() re.NoError(err) re.Nil(tls) } diff --git a/server/forward.go b/server/forward.go index 707b63a15ad..722594891ac 100644 --- a/server/forward.go +++ b/server/forward.go @@ -372,7 +372,7 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string return client.(*grpc.ClientConn), nil } - tlsConfig, err := s.GetTLSConfig().ToTLSConfig() + tlsConfig, err := s.GetTLSConfig().ToClientTLSConfig() if err != nil { return nil, err } diff --git a/server/join/join.go b/server/join/join.go index e77675f2196..b0470f48472 100644 --- a/server/join/join.go +++ b/server/join/join.go @@ -111,7 +111,7 @@ func PrepareJoinCluster(cfg *config.Config) error { } // Below are cases without data directory. - tlsConfig, err := cfg.Security.ToTLSConfig() + tlsConfig, err := cfg.Security.ToClientTLSConfig() if err != nil { return err } diff --git a/server/server.go b/server/server.go index 813a62d4569..87ac5d05987 100644 --- a/server/server.go +++ b/server/server.go @@ -338,7 +338,7 @@ func (s *Server) startEtcd(ctx context.Context) error { if err != nil { return errs.ErrEtcdURLMap.Wrap(err).GenWithStackByCause() } - tlsConfig, err := s.cfg.Security.ToTLSConfig() + tlsConfig, err := s.cfg.Security.ToClientTLSConfig() if err != nil { return err } @@ -379,7 +379,7 @@ func (s *Server) initGRPCServiceLabels() { } func (s *Server) startClient() error { - tlsConfig, err := s.cfg.Security.ToTLSConfig() + tlsConfig, err := s.cfg.Security.ToClientTLSConfig() if err != nil { return err } diff --git a/server/server_test.go b/server/server_test.go index 7dd91b9f61f..344df5693a1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -324,7 +324,7 @@ func TestCheckClusterID(t *testing.T) { re.NoError(err) urlsMap, err := etcdtypes.NewURLsMap(svr.cfg.InitialCluster) re.NoError(err) - tlsConfig, err := svr.cfg.Security.ToTLSConfig() + tlsConfig, err := svr.cfg.Security.ToClientTLSConfig() re.NoError(err) err = etcdutil.CheckClusterID(etcd.Server.Cluster().ID(), urlsMap, tlsConfig) re.Error(err) diff --git a/tests/integrations/client/client_tls_test.go b/tests/integrations/client/client_tls_test.go index 75706c3d902..ce47e05a17a 100644 --- a/tests/integrations/client/client_tls_test.go +++ b/tests/integrations/client/client_tls_test.go @@ -146,7 +146,7 @@ func testTLSReload( endpoints := make([]string, 0, len(testServers)) for _, s := range testServers { endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) - tlsConfig, err := s.GetConfig().Security.ToTLSConfig() + tlsConfig, err := s.GetConfig().Security.ToClientTLSConfig() re.NoError(err) httpClient := &http.Client{ Transport: &http.Transport{ diff --git a/tools/pd-heartbeat-bench/main.go b/tools/pd-heartbeat-bench/main.go index daf7c949bda..50e1fba1036 100644 --- a/tools/pd-heartbeat-bench/main.go +++ b/tools/pd-heartbeat-bench/main.go @@ -70,7 +70,7 @@ var ( ) func newClient(ctx context.Context, cfg *config.Config) (pdpb.PDClient, error) { - tlsConfig, err := cfg.Security.ToTLSConfig() + tlsConfig, err := cfg.Security.ToClientTLSConfig() if err != nil { return nil, err }