diff --git a/internal/grpc/grpc.go b/internal/grpc/grpc.go index 34f6f3503..a93c6f9c2 100644 --- a/internal/grpc/grpc.go +++ b/internal/grpc/grpc.go @@ -363,30 +363,32 @@ func getTransportCredentials(agentConfig *config.Config) (credentials.TransportC if agentConfig.Command.TLS == nil { return defaultCredentials, nil } + tlsConfig, err := getTLSConfigForCredentials(agentConfig.Command.TLS) + if err != nil { + return nil, err + } + + return credentials.NewTLS(tlsConfig), nil +} - if agentConfig.Command.TLS.SkipVerify { +func getTLSConfigForCredentials(c *config.TLSConfig) (*tls.Config, error) { + if c.SkipVerify { slog.Warn("Verification of the server's certificate chain and host name is disabled") } tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, - ServerName: agentConfig.Command.TLS.ServerName, - InsecureSkipVerify: agentConfig.Command.TLS.SkipVerify, + ServerName: c.ServerName, + InsecureSkipVerify: c.SkipVerify, } - if agentConfig.Command.TLS.Key == "" { - return credentials.NewTLS(tlsConfig), nil + if err := appendRootCAs(tlsConfig, c.Ca); err != nil { + return nil, fmt.Errorf("invalid CA cert while building transport credentials: %w", err) } - err := appendCertKeyPair(tlsConfig, agentConfig.Command.TLS.Cert, agentConfig.Command.TLS.Key) - if err != nil { - return nil, fmt.Errorf("append cert and key pair failed: %w", err) + if err := appendCertKeyPair(tlsConfig, c.Cert, c.Key); err != nil { + return nil, fmt.Errorf("invalid client cert while building transport credentials: %w", err) } - err = appendRootCAs(tlsConfig, agentConfig.Command.TLS.Ca) - if err != nil { - slog.Debug("Unable to append root CA", "error", err) - } - - return credentials.NewTLS(tlsConfig), nil + return tlsConfig, nil } diff --git a/internal/grpc/grpc_test.go b/internal/grpc/grpc_test.go index 4aea9ed6b..87828ceb3 100644 --- a/internal/grpc/grpc_test.go +++ b/internal/grpc/grpc_test.go @@ -7,19 +7,20 @@ package grpc import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "testing" - "google.golang.org/grpc/credentials" - "github.com/cenkalti/backoff/v4" - "github.com/nginx/agent/v3/test/helpers" - "github.com/nginx/agent/v3/test/protos" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "github.com/nginx/agent/v3/test/helpers" + "github.com/nginx/agent/v3/test/protos" + "github.com/nginx/agent/v3/internal/config" "github.com/nginx/agent/v3/test/types" @@ -103,7 +104,7 @@ func Test_GetDialOptions(t *testing.T) { { name: "Test 2: DialOptions mTLS", agentConfig: types.AgentConfig(), - expected: 7, + expected: 8, createCerts: true, }, { @@ -171,8 +172,8 @@ func Test_GetDialOptions(t *testing.T) { key, cert := helpers.GenerateSelfSignedCert(t) _, ca := helpers.GenerateSelfSignedCert(t) - keyContents := helpers.Cert{Name: keyFileName, Type: certificateType, Contents: key} - certContents := helpers.Cert{Name: certFileName, Type: privateKeyType, Contents: cert} + keyContents := helpers.Cert{Name: keyFileName, Type: privateKeyType, Contents: key} + certContents := helpers.Cert{Name: certFileName, Type: certificateType, Contents: cert} caContents := helpers.Cert{Name: caFileName, Type: certificateType, Contents: ca} helpers.WriteCertFiles(t, tmpDir, keyContents) @@ -356,28 +357,136 @@ func Test_ValidateGrpcError(t *testing.T) { } func Test_getTransportCredentials(t *testing.T) { - tests := []struct { - want credentials.TransportCredentials - conf *config.Config - wantErr assert.ErrorAssertionFunc - name string + tests := map[string]struct { + conf *config.Config + wantSecurityProfile string + wantServerName string + wantErr bool }{ - { - name: "No TLS config returns default credentials", + "Test 1: No TLS config returns default credentials": { conf: &config.Config{ Command: &config.Command{}, }, - want: defaultCredentials, - wantErr: assert.NoError, + wantErr: false, + wantSecurityProfile: "insecure", + }, + "Test 2: With tls config returns secure credentials": { + conf: &config.Config{ + Command: &config.Command{ + TLS: &config.TLSConfig{ + ServerName: "foobar", + SkipVerify: true, + }, + }, + }, + wantErr: false, + wantSecurityProfile: "tls", + }, + "Test 3: With invalid tls config should error": { + conf: types.AgentConfig(), // references non-existent certs + wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + for name, tt := range tests { + t.Run(name, func(t *testing.T) { got, err := getTransportCredentials(tt.conf) - if !tt.wantErr(t, err, fmt.Sprintf("getTransportCredentials(%v)", tt.conf)) { + if tt.wantErr { + require.Error(t, err, "getTransportCredentials(%v)", tt.conf) + return } - assert.Equalf(t, tt.want, got, "getTransportCredentials(%v)", tt.conf) + require.NoError(t, err, "getTransportCredentials(%v)", tt.conf) + require.Equal(t, tt.wantSecurityProfile, got.Info().SecurityProtocol, "incorrect SecurityProtocol") + }) + } +} + +func Test_getTLSConfig(t *testing.T) { + tmpDir := t.TempDir() + // not mTLS scripts + key, cert := helpers.GenerateSelfSignedCert(t) + _, ca := helpers.GenerateSelfSignedCert(t) + + keyContents := helpers.Cert{Name: keyFileName, Type: privateKeyType, Contents: key} + certContents := helpers.Cert{Name: certFileName, Type: certificateType, Contents: cert} + caContents := helpers.Cert{Name: caFileName, Type: certificateType, Contents: ca} + + keyPath := helpers.WriteCertFiles(t, tmpDir, keyContents) + certPath := helpers.WriteCertFiles(t, tmpDir, certContents) + caPath := helpers.WriteCertFiles(t, tmpDir, caContents) + + tests := map[string]struct { + conf *config.TLSConfig + verify func(require.TestingT, *tls.Config) + wantErr bool + }{ + "Test 1: all config should be translated": { + conf: &config.TLSConfig{ + Cert: certPath, + Key: keyPath, + Ca: caPath, + ServerName: "foobar", + SkipVerify: true, + }, + wantErr: false, + verify: func(t require.TestingT, c *tls.Config) { + require.NotEmpty(t, c.Certificates) + require.Equal(t, "foobar", c.ServerName, "wrong servername") + require.True(t, c.InsecureSkipVerify, "InsecureSkipVerify not set") + }, + }, + "Test 2: CA only config should use CA": { + conf: &config.TLSConfig{ + Ca: caPath, + }, + wantErr: false, + verify: func(t require.TestingT, c *tls.Config) { + require.NotNil(t, c.RootCAs, "RootCAs should be initialized") + require.False(t, x509.NewCertPool().Equal(c.RootCAs), + "CertPool shouldn't be empty, valid CA cert was specified") + require.False(t, c.InsecureSkipVerify, "InsecureSkipVerify should not be set") + }, + }, + "Test 3: incorrect CA should not error": { + conf: &config.TLSConfig{ + Ca: "customca.pem", + }, + wantErr: true, + }, + "Test 4: incorrect key path should error": { + conf: &config.TLSConfig{ + Ca: caPath, + Cert: certPath, + Key: "badkey", + }, + wantErr: true, + }, + "Test 5: incorrect cert path should error": { + conf: &config.TLSConfig{ + Ca: caPath, + Cert: "badcert", + Key: keyPath, + }, + wantErr: true, + }, + "Test 6: incomplete cert info should error": { + conf: &config.TLSConfig{ + Key: keyPath, + }, + wantErr: true, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + got, err := getTLSConfigForCredentials(tt.conf) + if tt.wantErr { + require.Error(t, err, "getTLSConfigForCredentials(%v)", tt.conf) + return + } + require.NoError(t, err, "getTLSConfigForCredentials(%v)", tt.conf) + if tt.verify != nil { + tt.verify(t, got) + } }) } } diff --git a/test/helpers/cert_utils.go b/test/helpers/cert_utils.go index 04cacfa77..c292e97a3 100644 --- a/test/helpers/cert_utils.go +++ b/test/helpers/cert_utils.go @@ -11,10 +11,9 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "math/big" "os" - "strings" + "path" "testing" "time" @@ -31,7 +30,7 @@ const ( permission = 0o600 serialNumber = 123123 years, months, days = 5, 0, 0 - bits = 4096 + bits = 1024 ) func GenerateSelfSignedCert(t testing.TB) (keyBytes, certBytes []byte) { @@ -73,12 +72,7 @@ func WriteCertFiles(t *testing.T, location string, cert Cert) string { Bytes: cert.Contents, }) - var certFile string - if strings.HasSuffix(location, string(os.PathSeparator)) { - certFile = fmt.Sprintf("%s%s", location, cert.Name) - } else { - certFile = fmt.Sprintf("%s%s%s", location, string(os.PathSeparator), cert.Name) - } + certFile := path.Join(location, cert.Name) err := os.WriteFile(certFile, pemContents, permission) require.NoError(t, err)