Skip to content

Commit

Permalink
Add simple vault client auth test (spiffe#5058)
Browse files Browse the repository at this point in the history
  • Loading branch information
InverseIntegral committed Sep 8, 2024
1 parent ed6cb04 commit bb9fb10
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 14 deletions.
12 changes: 4 additions & 8 deletions pkg/server/plugin/keymanager/hashicorpvault/hashicorp_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ type pluginHooks struct {
refreshKeysSignal chan error
disposeKeysSignal chan error

newClient func(*ClientConfig, AuthMethod, chan struct{}) (client cloudKeyManagementService, err error)
lookupEnv func(string) (string, bool)
}

Expand Down Expand Up @@ -76,25 +75,22 @@ type Plugin struct {

authMethod AuthMethod
cc *ClientConfig
vc cloudKeyManagementService
vc *Client

hooks pluginHooks
}

// New returns an instantiated plugin.
func New() *Plugin {
return newPlugin(func(config *ClientConfig, method AuthMethod, renewCh chan struct{}) (client cloudKeyManagementService, err error) {
return config.NewAuthenticatedClient(method, renewCh)
})
return newPlugin()
}

// newPlugin returns a new plugin instance.
func newPlugin(newClient func(*ClientConfig, AuthMethod, chan struct{}) (client cloudKeyManagementService, err error)) *Plugin {
func newPlugin() *Plugin {
return &Plugin{
entries: make(map[string]keyEntry),
hooks: pluginHooks{
lookupEnv: os.LookupEnv,
newClient: newClient,
},
}
}
Expand Down Expand Up @@ -360,7 +356,7 @@ func convertToTransitKeyType(keyType keymanagerv1.KeyType) (*TransitKeyType, err

func (p *Plugin) genVaultClient() error {
renewCh := make(chan struct{})
vc, err := p.hooks.newClient(p.cc, p.authMethod, renewCh)
vc, err := p.cc.NewAuthenticatedClient(p.authMethod, renewCh)
if err != nil {
return status.Errorf(codes.Internal, "failed to prepare authenticated client: %v", err)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBMjCB2aADAgECAgEBMAoGCCqGSM49BAMCMAAwIBgPMDAwMTAxMDEwMDAwMDBa
Fw0zMjA0MTIxNjA4NDRaMAAwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQaWBAL
TN4YPe4yQgMhDp9DZOPXaglEchzUo++feITLXN9XuUICLNWO9YEtAsaRsajul8Bc
GL9Rmbv2f6J2Lnueo0IwQDAOBgNVHQ8BAf8EBAMCAgQwDwYDVR0TAQH/BAUwAwEB
/zAdBgNVHQ4EFgQUmEs2MBzULBomV0lWA7OfcN/lGDcwCgYIKoZIzj0EAwIDSAAw
RQIhAP86wRV1PHg6rFkjl1Nx6He+Y2LSdOoEGnGlVM0ztzlUAiBpPhSMqonlFLZa
nLW9psyWrQMHai7KZLJjLfw+UMl0sQ==
-----END CERTIFICATE-----
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-----BEGIN CERTIFICATE-----
MIIBPTCB46ADAgECAgECMAoGCCqGSM49BAMCMAAwIBgPMDAwMTAxMDEwMDAwMDBa
Fw0zMjA0MTIxNjA4NDRaMAAwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS6v/nm
XmVkQGMfqDpEq6aiV/AnwcGAJBGTL/ixbDqCPD5crgrXaycLdbZqy8jYVA5uWfHh
Ps+5/8acn3cSSAc2o0wwSjATBgNVHSUEDDAKBggrBgEFBQcDATAfBgNVHSMEGDAW
gBSYSzYwHNQsGiZXSVYDs59w3+UYNzASBgNVHREBAf8ECDAGhwR/AAABMAoGCCqG
SM49BAMCA0kAMEYCIQDkCDZP2InFWBBazaVJZlIwMz/o2cm3K7xaPbVucHPuswIh
AJstcTQ/RjJKhfZQo7mOIHO+l5U0TeInMCYg9XEPcNJt
-----END CERTIFICATE-----
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgjHE1FFYDxseFqNrC
jjh72BLj5tHTh5vIMcdn0w3W1PKhRANCAAS6v/nmXmVkQGMfqDpEq6aiV/AnwcGA
JBGTL/ixbDqCPD5crgrXaycLdbZqy8jYVA5uWfHhPs+5/8acn3cSSAc2
-----END PRIVATE KEY-----
6 changes: 0 additions & 6 deletions pkg/server/plugin/keymanager/hashicorpvault/vault_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ const (
defaultK8sMountPoint = "kubernetes"
)

type cloudKeyManagementService interface {
CreateKey(ctx context.Context, spireKeyID string, keyType TransitKeyType) error
GetKey(ctx context.Context, spireKeyID string) (string, error)
SignData(ctx context.Context, spireKeyID string, data []byte, hashAlgo TransitHashAlgorithm, signatureAlgo TransitSignatureAlgorithm) ([]byte, error)
}

type AuthMethod int

const (
Expand Down
185 changes: 185 additions & 0 deletions pkg/server/plugin/keymanager/hashicorpvault/vault_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package hashicorpvault

import (
"fmt"
"testing"
"time"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/spiffe/spire/test/spiretest"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
)

const (
testRootCert = "testdata/root-cert.pem"
testServerCert = "testdata/server-cert.pem"
testServerKey = "testdata/server-key.pem"
)

func TestNewClientConfigWithDefaultValues(t *testing.T) {
p := &ClientParams{
VaultAddr: "http://example.org:8200/",
PKIMountPoint: "", // Expect the default value to be used.
Token: "test-token",
CertAuthMountPoint: "", // Expect the default value to be used.
AppRoleAuthMountPoint: "", // Expect the default value to be used.
K8sAuthMountPoint: "", // Expect the default value to be used.
}

cc, err := NewClientConfig(p, hclog.Default())
require.NoError(t, err)
require.Equal(t, defaultPKIMountPoint, cc.clientParams.PKIMountPoint)
require.Equal(t, defaultCertMountPoint, cc.clientParams.CertAuthMountPoint)
require.Equal(t, defaultAppRoleMountPoint, cc.clientParams.AppRoleAuthMountPoint)
require.Equal(t, defaultK8sMountPoint, cc.clientParams.K8sAuthMountPoint)
}

func TestNewClientConfigWithGivenValuesInsteadOfDefaults(t *testing.T) {
p := &ClientParams{
VaultAddr: "http://example.org:8200/",
PKIMountPoint: "test-pki",
Token: "test-token",
CertAuthMountPoint: "test-tls-cert",
AppRoleAuthMountPoint: "test-approle",
K8sAuthMountPoint: "test-k8s",
}

cc, err := NewClientConfig(p, hclog.Default())
require.NoError(t, err)
require.Equal(t, "test-pki", cc.clientParams.PKIMountPoint)
require.Equal(t, "test-tls-cert", cc.clientParams.CertAuthMountPoint)
require.Equal(t, "test-approle", cc.clientParams.AppRoleAuthMountPoint)
require.Equal(t, "test-k8s", cc.clientParams.K8sAuthMountPoint)
}

func TestNewAuthenticatedClientTokenAuth(t *testing.T) {
fakeVaultServer := newFakeVaultServer()
fakeVaultServer.LookupSelfResponseCode = 200
for _, tt := range []struct {
name string
token string
response []byte
renew bool
namespace string
expectCode codes.Code
expectMsgPrefix string
}{
{
name: "Token Authentication success / Token never expire",
token: "test-token",
response: []byte(testLookupSelfResponseNeverExpire),
renew: true,
},
{
name: "Token Authentication success / Token is renewable",
token: "test-token",
response: []byte(testLookupSelfResponse),
renew: true,
},
{
name: "Token Authentication success / Token is not renewable",
token: "test-token",
response: []byte(testLookupSelfResponseNotRenewable),
},
{
name: "Token Authentication success / Token is renewable / Namespace is given",
token: "test-token",
response: []byte(testCertAuthResponse),
renew: true,
namespace: "test-ns",
},
{
name: "Token Authentication error / Token is empty",
token: "",
response: []byte(testCertAuthResponse),
renew: true,
namespace: "test-ns",
expectCode: codes.InvalidArgument,
expectMsgPrefix: "token is empty",
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
fakeVaultServer.LookupSelfResponse = tt.response

s, addr, err := fakeVaultServer.NewTLSServer()
require.NoError(t, err)

s.Start()
defer s.Close()

cp := &ClientParams{
VaultAddr: fmt.Sprintf("https://%v/", addr),
Namespace: tt.namespace,
CACertPath: testRootCert,
Token: tt.token,
}
cc, err := NewClientConfig(cp, hclog.Default())
require.NoError(t, err)

renewCh := make(chan struct{})
client, err := cc.NewAuthenticatedClient(TOKEN, renewCh)
if tt.expectMsgPrefix != "" {
spiretest.RequireGRPCStatusHasPrefix(t, err, tt.expectCode, tt.expectMsgPrefix)
return
}

require.NoError(t, err)

select {
case <-renewCh:
require.Equal(t, false, tt.renew)
default:
require.Equal(t, true, tt.renew)
}

if cp.Namespace != "" {
headers := client.vaultClient.Headers()
require.Equal(t, cp.Namespace, headers.Get(consts.NamespaceHeaderName))
}
})
}
}

func TestRenewTokenFailed(t *testing.T) {
fakeVaultServer := newFakeVaultServer()
fakeVaultServer.LookupSelfResponse = []byte(testLookupSelfResponseShortTTL)
fakeVaultServer.LookupSelfResponseCode = 200
fakeVaultServer.RenewResponse = []byte("fake renew error")
fakeVaultServer.RenewResponseCode = 500

s, addr, err := fakeVaultServer.NewTLSServer()
require.NoError(t, err)

s.Start()
defer s.Close()

retry := 0
cp := &ClientParams{
MaxRetries: &retry,
VaultAddr: fmt.Sprintf("https://%v/", addr),
CACertPath: testRootCert,
Token: "test-token",
}
cc, err := NewClientConfig(cp, hclog.Default())
require.NoError(t, err)

renewCh := make(chan struct{})
_, err = cc.NewAuthenticatedClient(TOKEN, renewCh)
require.NoError(t, err)

select {
case <-renewCh:
case <-time.After(1 * time.Second):
t.Error("renewChan did not close in the expected time")
}
}

func newFakeVaultServer() *FakeVaultServerConfig {
fakeVaultServer := NewFakeVaultServerConfig()
fakeVaultServer.RenewResponseCode = 200
fakeVaultServer.RenewResponse = []byte(testRenewResponse)
return fakeVaultServer
}
Loading

0 comments on commit bb9fb10

Please sign in to comment.