Skip to content

Commit 50e82ed

Browse files
committed
fix: Use specified CA cert for grpc
This had been skipping out of the function early if a client key wasn't specified. I don't believe that's correct. If I[User] have specified specified a CA cert because the MPI server I'm trying to talk to is signed by a non-standard CA (e.g. N1 devenv) then it should be respected regardless of whether I've configured mTLS. Silently skipping the CA is really confusing and leads to > Failed to create connection" error="rpc error: code = Unavailable desc = connection error: desc = \"transport: authentication handshake failed: tls: failed to verify certificate: x509: certificate signed by unknown authority\" I've split out the getTLSConfigForCredentials to make it easier to test this translation. Once it is wrapped into a TransportCredential or a DialOption it's opaque and hard to verify.
1 parent 8d5f8a2 commit 50e82ed

File tree

3 files changed

+147
-37
lines changed

3 files changed

+147
-37
lines changed

internal/grpc/grpc.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -363,30 +363,34 @@ func getTransportCredentials(agentConfig *config.Config) (credentials.TransportC
363363
if agentConfig.Command.TLS == nil {
364364
return defaultCredentials, nil
365365
}
366+
tlsConfig, err := getTLSConfigForCredentials(agentConfig.Command.TLS)
367+
if err != nil {
368+
return nil, err
369+
}
366370

367-
if agentConfig.Command.TLS.SkipVerify {
371+
return credentials.NewTLS(tlsConfig), nil
372+
}
373+
374+
func getTLSConfigForCredentials(c *config.TLSConfig) (*tls.Config, error) {
375+
if c.SkipVerify {
368376
slog.Warn("Verification of the server's certificate chain and host name is disabled")
369377
}
370378

371379
tlsConfig := &tls.Config{
372380
MinVersion: tls.VersionTLS12,
373-
ServerName: agentConfig.Command.TLS.ServerName,
374-
InsecureSkipVerify: agentConfig.Command.TLS.SkipVerify,
381+
ServerName: c.ServerName,
382+
InsecureSkipVerify: c.SkipVerify,
375383
}
376384

377-
if agentConfig.Command.TLS.Key == "" {
378-
return credentials.NewTLS(tlsConfig), nil
379-
}
380-
381-
err := appendCertKeyPair(tlsConfig, agentConfig.Command.TLS.Cert, agentConfig.Command.TLS.Key)
385+
err := appendRootCAs(tlsConfig, c.Ca)
382386
if err != nil {
383-
return nil, fmt.Errorf("append cert and key pair failed: %w", err)
387+
slog.Debug("Unable to append root CA", "error", err)
384388
}
385389

386-
err = appendRootCAs(tlsConfig, agentConfig.Command.TLS.Ca)
390+
err = appendCertKeyPair(tlsConfig, c.Cert, c.Key)
387391
if err != nil {
388-
slog.Debug("Unable to append root CA", "error", err)
392+
return nil, fmt.Errorf("append cert and key pair failed: %w", err)
389393
}
390394

391-
return credentials.NewTLS(tlsConfig), nil
395+
return tlsConfig, nil
392396
}

internal/grpc/grpc_test.go

Lines changed: 129 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@ package grpc
77

88
import (
99
"context"
10+
"crypto/tls"
11+
"crypto/x509"
1012
"fmt"
1113
"testing"
1214

13-
"google.golang.org/grpc/credentials"
14-
1515
"github.com/cenkalti/backoff/v4"
16-
"github.com/nginx/agent/v3/test/helpers"
17-
"github.com/nginx/agent/v3/test/protos"
1816
"google.golang.org/grpc"
1917
"google.golang.org/grpc/codes"
2018
"google.golang.org/grpc/metadata"
2119
"google.golang.org/grpc/status"
2220

21+
"github.com/nginx/agent/v3/test/helpers"
22+
"github.com/nginx/agent/v3/test/protos"
23+
2324
"github.com/nginx/agent/v3/internal/config"
2425
"github.com/nginx/agent/v3/test/types"
2526

@@ -356,28 +357,139 @@ func Test_ValidateGrpcError(t *testing.T) {
356357
}
357358

358359
func Test_getTransportCredentials(t *testing.T) {
359-
tests := []struct {
360-
want credentials.TransportCredentials
361-
conf *config.Config
362-
wantErr assert.ErrorAssertionFunc
363-
name string
360+
tests := map[string]struct {
361+
conf *config.Config
362+
wantSecurityProfile string
363+
wantServerName string
364+
wantErr bool
364365
}{
365-
{
366-
name: "No TLS config returns default credentials",
366+
"Test 1: No TLS config returns default credentials": {
367367
conf: &config.Config{
368368
Command: &config.Command{},
369369
},
370-
want: defaultCredentials,
371-
wantErr: assert.NoError,
370+
wantErr: false,
371+
wantSecurityProfile: "insecure",
372+
},
373+
"Test 2: With tls config returns secure credentials": {
374+
conf: &config.Config{
375+
Command: &config.Command{
376+
TLS: &config.TLSConfig{
377+
ServerName: "foobar",
378+
SkipVerify: true,
379+
},
380+
},
381+
},
382+
wantErr: false,
383+
wantSecurityProfile: "tls",
384+
},
385+
"Test 3: With invalid tls config should error": {
386+
conf: types.AgentConfig(), // references non-existent certs
387+
wantErr: true,
372388
},
373389
}
374-
for _, tt := range tests {
375-
t.Run(tt.name, func(t *testing.T) {
390+
for name, tt := range tests {
391+
t.Run(name, func(t *testing.T) {
376392
got, err := getTransportCredentials(tt.conf)
377-
if !tt.wantErr(t, err, fmt.Sprintf("getTransportCredentials(%v)", tt.conf)) {
393+
if tt.wantErr {
394+
require.Error(t, err, "getTransportCredentials(%v)", tt.conf)
395+
396+
return
397+
}
398+
require.NoError(t, err, "getTransportCredentials(%v)", tt.conf)
399+
require.Equal(t, tt.wantSecurityProfile, got.Info().SecurityProtocol, "incorrect SecurityProtocol")
400+
})
401+
}
402+
}
403+
404+
func Test_getTLSConfig(t *testing.T) {
405+
tmpDir := t.TempDir()
406+
// not mTLS scripts
407+
key, cert := helpers.GenerateSelfSignedCert(t)
408+
_, ca := helpers.GenerateSelfSignedCert(t)
409+
410+
keyContents := helpers.Cert{Name: keyFileName, Type: privateKeyType, Contents: key}
411+
certContents := helpers.Cert{Name: certFileName, Type: certificateType, Contents: cert}
412+
caContents := helpers.Cert{Name: caFileName, Type: certificateType, Contents: ca}
413+
414+
keyPath := helpers.WriteCertFiles(t, tmpDir, keyContents)
415+
certPath := helpers.WriteCertFiles(t, tmpDir, certContents)
416+
caPath := helpers.WriteCertFiles(t, tmpDir, caContents)
417+
418+
tests := map[string]struct {
419+
conf *config.TLSConfig
420+
verify func(require.TestingT, *tls.Config)
421+
wantErr bool
422+
}{
423+
"Test 1: all config should be translated": {
424+
conf: &config.TLSConfig{
425+
Cert: certPath,
426+
Key: keyPath,
427+
Ca: caPath,
428+
ServerName: "foobar",
429+
SkipVerify: true,
430+
},
431+
wantErr: false,
432+
verify: func(t require.TestingT, c *tls.Config) {
433+
require.NotEmpty(t, c.Certificates)
434+
require.Equal(t, "foobar", c.ServerName, "wrong servername")
435+
require.True(t, c.InsecureSkipVerify, "InsecureSkipVerify not set")
436+
},
437+
},
438+
"Test 2: CA only config should use CA": {
439+
conf: &config.TLSConfig{
440+
Ca: caPath,
441+
},
442+
wantErr: false,
443+
verify: func(t require.TestingT, c *tls.Config) {
444+
require.NotNil(t, c.RootCAs, "RootCAs should be initialized")
445+
require.False(t, x509.NewCertPool().Equal(c.RootCAs),
446+
"CertPool shouldn't be empty, valid CA cert was specified")
447+
require.False(t, c.InsecureSkipVerify, "InsecureSkipVerify should not be set")
448+
},
449+
},
450+
"Test 3: incorrect CA should not error": { // REALLY ?!
451+
conf: &config.TLSConfig{
452+
Ca: "customca.pem",
453+
},
454+
wantErr: false,
455+
verify: func(t require.TestingT, c *tls.Config) {
456+
require.Nil(t, c.RootCAs, "RootCAs should be nil to use system")
457+
},
458+
},
459+
"Test 4: incorrect key path should error": {
460+
conf: &config.TLSConfig{
461+
Ca: caPath,
462+
Cert: certPath,
463+
Key: "badkey",
464+
},
465+
wantErr: true,
466+
},
467+
"Test 5: incorrect cert path should error": {
468+
conf: &config.TLSConfig{
469+
Ca: caPath,
470+
Cert: "badcert",
471+
Key: keyPath,
472+
},
473+
wantErr: true,
474+
},
475+
"Test 6: incomplete cert info should error": {
476+
conf: &config.TLSConfig{
477+
Key: keyPath,
478+
},
479+
wantErr: true,
480+
},
481+
}
482+
for name, tt := range tests {
483+
t.Run(name, func(t *testing.T) {
484+
got, err := getTLSConfigForCredentials(tt.conf)
485+
if tt.wantErr {
486+
require.Error(t, err, "getTLSConfigForCredentials(%v)", tt.conf)
378487
return
379488
}
380-
assert.Equalf(t, tt.want, got, "getTransportCredentials(%v)", tt.conf)
489+
require.NoError(t, err, "getTLSConfigForCredentials(%v)", tt.conf)
490+
if tt.verify != nil {
491+
tt.verify(t, got)
492+
}
381493
})
382494
}
383495
}

test/helpers/cert_utils.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ import (
1111
"crypto/x509"
1212
"crypto/x509/pkix"
1313
"encoding/pem"
14-
"fmt"
1514
"math/big"
1615
"os"
17-
"strings"
16+
"path"
1817
"testing"
1918
"time"
2019

@@ -73,12 +72,7 @@ func WriteCertFiles(t *testing.T, location string, cert Cert) string {
7372
Bytes: cert.Contents,
7473
})
7574

76-
var certFile string
77-
if strings.HasSuffix(location, string(os.PathSeparator)) {
78-
certFile = fmt.Sprintf("%s%s", location, cert.Name)
79-
} else {
80-
certFile = fmt.Sprintf("%s%s%s", location, string(os.PathSeparator), cert.Name)
81-
}
75+
certFile := path.Join(location, cert.Name)
8276

8377
err := os.WriteFile(certFile, pemContents, permission)
8478
require.NoError(t, err)

0 commit comments

Comments
 (0)