Skip to content

Commit 1808146

Browse files
committed
fix: Use specified CA cert for grpc
This had been skipping out of the function early if a client key wasn't specified. I don't believe that's correct. If I[User] have specified specified a CA cert because the MPI server I'm trying to talk to is signed by a non-standard CA (e.g. N1 devenv) then it should be respected regardless of whether I've configured mTLS. Silently skipping the CA is really confusing and leads to > Failed to create connection" error="rpc error: code = Unavailable desc = connection error: desc = \"transport: authentication handshake failed: tls: failed to verify certificate: x509: certificate signed by unknown authority\" I've split out the getTLSConfigForCredentials to make it easier to test this translation. Once it is wrapped into a TransportCredential or a DialOption it's opaque and hard to verify.
1 parent 659c8af commit 1808146

File tree

3 files changed

+134
-36
lines changed

3 files changed

+134
-36
lines changed

internal/grpc/grpc.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -363,30 +363,32 @@ func getTransportCredentials(agentConfig *config.Config) (credentials.TransportC
363363
if agentConfig.Command.TLS == nil {
364364
return defaultCredentials, nil
365365
}
366+
tlsConfig, err := getTLSConfigForCredentials(agentConfig.Command.TLS)
367+
if err != nil {
368+
return nil, err
369+
}
370+
return credentials.NewTLS(tlsConfig), nil
371+
}
366372

367-
if agentConfig.Command.TLS.SkipVerify {
373+
func getTLSConfigForCredentials(c *config.TLSConfig) (*tls.Config, error) {
374+
if c.SkipVerify {
368375
slog.Warn("Verification of the server's certificate chain and host name is disabled")
369376
}
370377

371378
tlsConfig := &tls.Config{
372379
MinVersion: tls.VersionTLS12,
373-
ServerName: agentConfig.Command.TLS.ServerName,
374-
InsecureSkipVerify: agentConfig.Command.TLS.SkipVerify,
375-
}
376-
377-
if agentConfig.Command.TLS.Key == "" {
378-
return credentials.NewTLS(tlsConfig), nil
380+
ServerName: c.ServerName,
381+
InsecureSkipVerify: c.SkipVerify,
379382
}
380383

381-
err := appendCertKeyPair(tlsConfig, agentConfig.Command.TLS.Cert, agentConfig.Command.TLS.Key)
384+
err := appendRootCAs(tlsConfig, c.Ca)
382385
if err != nil {
383-
return nil, fmt.Errorf("append cert and key pair failed: %w", err)
386+
slog.Debug("Unable to append root CA", "error", err)
384387
}
385388

386-
err = appendRootCAs(tlsConfig, agentConfig.Command.TLS.Ca)
389+
err = appendCertKeyPair(tlsConfig, c.Cert, c.Key)
387390
if err != nil {
388-
slog.Debug("Unable to append root CA", "error", err)
391+
return nil, fmt.Errorf("append cert and key pair failed: %w", err)
389392
}
390-
391-
return credentials.NewTLS(tlsConfig), nil
393+
return tlsConfig, nil
392394
}

internal/grpc/grpc_test.go

Lines changed: 117 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ package grpc
77

88
import (
99
"context"
10+
"crypto/tls"
1011
"fmt"
1112
"testing"
1213

13-
"google.golang.org/grpc/credentials"
14-
1514
"github.com/cenkalti/backoff/v4"
1615
"github.com/nginx/agent/v3/test/helpers"
1716
"github.com/nginx/agent/v3/test/protos"
@@ -356,28 +355,131 @@ func Test_ValidateGrpcError(t *testing.T) {
356355
}
357356

358357
func Test_getTransportCredentials(t *testing.T) {
359-
tests := []struct {
360-
want credentials.TransportCredentials
361-
conf *config.Config
362-
wantErr assert.ErrorAssertionFunc
363-
name string
358+
tests := map[string]struct {
359+
conf *config.Config
360+
wantErr bool
361+
wantSecurityProfile string
362+
wantServerName string
364363
}{
365-
{
366-
name: "No TLS config returns default credentials",
364+
"Test 1: No TLS config returns default credentials": {
367365
conf: &config.Config{
368366
Command: &config.Command{},
369367
},
370-
want: defaultCredentials,
371-
wantErr: assert.NoError,
368+
wantErr: false,
369+
wantSecurityProfile: "insecure",
370+
},
371+
"Test 2: With tls config returns secure credentials": {
372+
conf: &config.Config{
373+
Command: &config.Command{
374+
TLS: &config.TLSConfig{
375+
ServerName: "foobar",
376+
SkipVerify: true,
377+
},
378+
},
379+
},
380+
wantErr: false,
381+
wantSecurityProfile: "tls",
382+
},
383+
"Test 3: With invalid tls config should error": {
384+
conf: types.AgentConfig(), // references non-existant certs
385+
wantErr: true,
372386
},
373387
}
374-
for _, tt := range tests {
375-
t.Run(tt.name, func(t *testing.T) {
388+
for name, tt := range tests {
389+
t.Run(name, func(t *testing.T) {
376390
got, err := getTransportCredentials(tt.conf)
377-
if !tt.wantErr(t, err, fmt.Sprintf("getTransportCredentials(%v)", tt.conf)) {
391+
if tt.wantErr {
392+
require.Error(t, err, "getTransportCredentials(%v)", tt.conf)
393+
return
394+
}
395+
require.NoError(t, err, "getTransportCredentials(%v)", tt.conf)
396+
require.Equal(t, tt.wantSecurityProfile, got.Info().SecurityProtocol, "incorrect SecurityProtocol")
397+
})
398+
}
399+
}
400+
401+
func Test_getTLSConfig(t *testing.T) {
402+
tmpDir := t.TempDir()
403+
// not mTLS scripts
404+
key, cert := helpers.GenerateSelfSignedCert(t)
405+
_, ca := helpers.GenerateSelfSignedCert(t)
406+
407+
keyContents := helpers.Cert{Name: keyFileName, Type: privateKeyType, Contents: key}
408+
certContents := helpers.Cert{Name: certFileName, Type: certificateType, Contents: cert}
409+
caContents := helpers.Cert{Name: caFileName, Type: certificateType, Contents: ca}
410+
411+
keyPath := helpers.WriteCertFiles(t, tmpDir, keyContents)
412+
certPath := helpers.WriteCertFiles(t, tmpDir, certContents)
413+
caPath := helpers.WriteCertFiles(t, tmpDir, caContents)
414+
415+
tests := map[string]struct {
416+
conf *config.TLSConfig
417+
wantErr bool
418+
verify func(*testing.T, *tls.Config)
419+
}{
420+
"Test 1: all config should be translated": {
421+
conf: &config.TLSConfig{
422+
Cert: certPath,
423+
Key: keyPath,
424+
Ca: caPath,
425+
ServerName: "foobar",
426+
SkipVerify: true,
427+
},
428+
wantErr: false,
429+
verify: func(t *testing.T, c *tls.Config) {
430+
require.NotEmpty(t, c.Certificates)
431+
require.Equal(t, "foobar", c.ServerName, "wrong servername")
432+
require.True(t, c.InsecureSkipVerify, "InsecureSkipVerify not set")
433+
},
434+
},
435+
"Test 2: CA only config should use CA": {
436+
conf: &config.TLSConfig{
437+
Ca: caPath,
438+
},
439+
wantErr: false,
440+
verify: func(t *testing.T, c *tls.Config) {
441+
require.NotNil(t, c.RootCAs, "RootCAs should be initialized")
442+
require.Len(t, c.RootCAs.Subjects(), 1, "RootCAs pool should contain at least one subject")
443+
require.False(t, c.InsecureSkipVerify, "InsecureSkipVerify should not be set")
444+
},
445+
},
446+
"Test 3: incorrect CA should not error": { // REALLY ?!
447+
conf: &config.TLSConfig{
448+
Ca: "customca.pem",
449+
},
450+
wantErr: false,
451+
verify: func(t *testing.T, c *tls.Config) {
452+
require.Nil(t, c.RootCAs, "RootCAs should be nil to use system")
453+
},
454+
},
455+
"Test 4: incorrect key path should error": {
456+
conf: &config.TLSConfig{
457+
Ca: caPath,
458+
Cert: certPath,
459+
Key: "badkey",
460+
},
461+
wantErr: true,
462+
},
463+
"Test 5: incorrect cert path should error": {
464+
conf: &config.TLSConfig{
465+
Ca: caPath,
466+
Cert: "badcert",
467+
Key: keyPath,
468+
},
469+
wantErr: true,
470+
},
471+
}
472+
for name, tt := range tests {
473+
t.Run(name, func(t *testing.T) {
474+
got, err := getTLSConfigForCredentials(tt.conf)
475+
if tt.wantErr {
476+
require.Error(t, err, "getTLSConfigForCredentials(%v)", tt.conf)
378477
return
379478
}
380-
assert.Equalf(t, tt.want, got, "getTransportCredentials(%v)", tt.conf)
479+
require.NoError(t, err, "getTLSConfigForCredentials(%v)", tt.conf)
480+
if tt.verify != nil {
481+
tt.verify(t, got)
482+
}
381483
})
382484
}
383485
}

test/helpers/cert_utils.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ import (
1111
"crypto/x509"
1212
"crypto/x509/pkix"
1313
"encoding/pem"
14-
"fmt"
1514
"math/big"
1615
"os"
17-
"strings"
16+
"path"
1817
"testing"
1918
"time"
2019

@@ -73,12 +72,7 @@ func WriteCertFiles(t *testing.T, location string, cert Cert) string {
7372
Bytes: cert.Contents,
7473
})
7574

76-
var certFile string
77-
if strings.HasSuffix(location, string(os.PathSeparator)) {
78-
certFile = fmt.Sprintf("%s%s", location, cert.Name)
79-
} else {
80-
certFile = fmt.Sprintf("%s%s%s", location, string(os.PathSeparator), cert.Name)
81-
}
75+
certFile := path.Join(location, cert.Name)
8276

8377
err := os.WriteFile(certFile, pemContents, permission)
8478
require.NoError(t, err)

0 commit comments

Comments
 (0)