From 7d546c6816dcd860d09c89147dcc7b94f61f0214 Mon Sep 17 00:00:00 2001 From: Peyton Walters Date: Wed, 8 Jan 2020 17:49:40 -0500 Subject: [PATCH] Move to new spire trust source design with state tracking --- cmd/plugin/vault-auth-spire.go | 2 +- internal/common/filetrustsource.go | 31 ------- internal/common/settings.go | 10 +-- internal/common/spiretrustsource.go | 107 ++++++++++++++--------- internal/common/spiretrustsource_test.go | 3 + 5 files changed, 75 insertions(+), 78 deletions(-) diff --git a/cmd/plugin/vault-auth-spire.go b/cmd/plugin/vault-auth-spire.go index 786b1d1..a258283 100644 --- a/cmd/plugin/vault-auth-spire.go +++ b/cmd/plugin/vault-auth-spire.go @@ -114,7 +114,7 @@ func BackendFactory(ctx context.Context, backendConfig *logical.BackendConfig) ( spirePlugin.verifier.AddTrustSource(trustSource) } if settings.SourceOfTrust.Spire != nil { - trustSource, err := common.NewSpireTrustSource(settings.SourceOfTrust.Spire.SpireEndpointUrls, settings.SourceOfTrust.Spire.LocalBackupPath) + trustSource, err := common.NewSpireTrustSource(settings.SourceOfTrust.Spire.SpireEndpointURLs, settings.SourceOfTrust.Spire.LocalBackupPath) if err != nil { return nil, errors.New("vault-auth-spire: Failed to initialize spire TrustSource - " + err.Error()) } diff --git a/internal/common/filetrustsource.go b/internal/common/filetrustsource.go index 96de17b..b5498fd 100644 --- a/internal/common/filetrustsource.go +++ b/internal/common/filetrustsource.go @@ -18,12 +18,9 @@ package common import ( "crypto/x509" - "encoding/pem" "errors" "fmt" "io/ioutil" - "os" - "strings" "github.com/sirupsen/logrus" ) @@ -61,34 +58,6 @@ func (source *FileTrustSource) TrustedCertificates() map[string][]*x509.Certific return source.domainCertificates } -func (source *FileTrustSource) updateCertificates(certs []*x509.Certificate, spiffeID, path string) error { - builder := strings.Builder{} - for _, cert := range certs { - block := &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - } - builder.Write(pem.EncodeToMemory(block)) - } - - file, err := appFS.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return err - } - defer file.Close() - _, err = file.WriteString(builder.String()) - if err != nil { - return err - } - - err = source.loadDomain(spiffeID) - if err != nil { - return err - } - - return nil -} - // For each domain/file mapping found in source.domainPaths, load the PEM and read all // certificates from the file. func (source *FileTrustSource) loadCertificates() error { diff --git a/internal/common/settings.go b/internal/common/settings.go index 0eeb9de..f4c0b4e 100644 --- a/internal/common/settings.go +++ b/internal/common/settings.go @@ -36,7 +36,7 @@ type FileTrustSourceSettings struct { } type SpireTrustSourceSettings struct { - SpireEndpointUrls map[string]string + SpireEndpointURLs map[string]string LocalBackupPath string } @@ -143,14 +143,14 @@ func readSpireSourceOfTrustSettings() (*SpireTrustSourceSettings, error) { return nil, errors.New("trustsource.spire.domains is required but not found") } - viper.SetDefault("trustsource.spire.certLocation", "/var/run/spire/certs/") + viper.SetDefault("trustsource.spire.backupPath", "/var/run/spire/certs/") viper.SetDefault("trustsource.spire.storeEnabled", true) spireSettings := &SpireTrustSourceSettings{ - SpireEndpoints: viper.GetStringMapString("trustsource.spire.domains"), - CertStorePath: viper.GetString("trustsource.spire.certLocation"), + SpireEndpointURLs: viper.GetStringMapString("trustsource.spire.domains"), + LocalBackupPath: viper.GetString("trustsource.spire.backupPath"), } if !viper.GetBool("trustsource.spire.storeEnabled") { - spireSettings.CertStorePath = "" + spireSettings.LocalBackupPath = "" } return spireSettings, nil diff --git a/internal/common/spiretrustsource.go b/internal/common/spiretrustsource.go index d5e85a1..6eff614 100644 --- a/internal/common/spiretrustsource.go +++ b/internal/common/spiretrustsource.go @@ -2,7 +2,7 @@ package common import ( "crypto/x509" - "errors" + "encoding/pem" "fmt" "os" "regexp" @@ -13,6 +13,7 @@ import ( "github.com/spiffe/go-spiffe/workload" ) +// SpireLoadState represents the current state of a Spire connection type SpireLoadState int const ( @@ -22,14 +23,16 @@ const ( Failed ) -// SpireTrustSource holds all necessary information to connect to a spire instance and store its certificates +// SpireEndpoint represents a single trust domain and its associated spire server connection type SpireEndpoint struct { domain string loadState SpireLoadState - spireUrl string + spireURL string client *workload.X509SVIDClient } + +// SpireTrustSource holds all necessary information to connect to a spire instance and store its certificates type SpireTrustSource struct { spireEndpoints map[string]*SpireEndpoint @@ -57,7 +60,7 @@ func (s *SpireTrustSource) TrustedCertificates() map[string][]*x509.Certificate } // NewSpireTrustSource creates a new trust source with connectivity to one or more spire instances. -func NewSpireTrustSource(spireEndpointUrls map[string]string, localBackupDir string) (*SpireTrustSource, error) { +func NewSpireTrustSource(spireEndpointURLs map[string]string, localBackupDir string) (*SpireTrustSource, error) { source := &SpireTrustSource{ spireEndpoints: make(map[string]*SpireEndpoint, 0), localBackupDir: localBackupDir, @@ -66,35 +69,44 @@ func NewSpireTrustSource(spireEndpointUrls map[string]string, localBackupDir str updateTimeout: 5 * time.Second, } - for domain, spireUrl := range spireEndpointUrls { + re := regexp.MustCompile(`^spiffe://(\S+)$`) + for trustDomain, spireURL := range spireEndpointURLs { localStoragePath := "" if localBackupDir != "" { - localStoragePath = "" // TODO: pull out from domain + matches := re.FindStringSubmatch(trustDomain) + if len(matches) < 2 { + return nil, fmt.Errorf("expected domain of form spiffe:// but got %s", trustDomain) + } + domain := matches[1] + if strings.Contains(domain, "/") { + return nil, fmt.Errorf("expected domain without slash but got %s", domain) + } + localStoragePath = localBackupDir + "/" + domain + ".pem" } client, err := workload.NewX509SVIDClient( &workloadWatcher{ - domain: domain, + domain: trustDomain, source: source, localStoragePath: localStoragePath, }, - workload.WithAddr(spireUrl), + workload.WithAddr(spireURL), ) if err != nil { - return nil, fmt.Errorf("failed to construct a new NewX509SVIDClient for %s - %v", spireUrl, err) + return nil, fmt.Errorf("failed to construct a new NewX509SVIDClient for %s: %v", spireURL, err) } - source.spireEndpoints[domain] = &SpireEndpoint{ - domain: domain, + source.spireEndpoints[trustDomain] = &SpireEndpoint{ + domain: trustDomain, loadState: Pending, - spireUrl: spireUrl, + spireURL: spireURL, client: client, } } for _, spireEndpoint := range source.spireEndpoints { if err := spireEndpoint.client.Start(); err != nil { - return nil, fmt.Errorf("failed to start NewX509SVIDClient for domain %s - %v", spireEndpoint.domain, err) + return nil, fmt.Errorf("failed to start NewX509SVIDClient for domain %s: %v", spireEndpoint.domain, err) } } @@ -106,7 +118,7 @@ func (s *SpireTrustSource) Stop() error { errs := make([]string, 0) for _, spireEndpoint := range s.spireEndpoints { if err := spireEndpoint.client.Stop(); err != nil { - errs = append(errs, fmt.Sprintf("domain %s - %v", spireEndpoint.domain, err)) + errs = append(errs, fmt.Sprintf("domain %s: %v", spireEndpoint.domain, err)) } } @@ -118,28 +130,35 @@ func (s *SpireTrustSource) Stop() error { } func (w *workloadWatcher) UpdateX509SVIDs(svids *workload.X509SVIDs) { + certs := svids.Default().TrustBundle + w.source.domainCertificates[w.domain] = certs + w.source.spireEndpoints[w.domain].loadState = Loaded + + if w.localStoragePath != "" { + builder := strings.Builder{} + for _, cert := range certs { + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + } + builder.Write(pem.EncodeToMemory(block)) + } + file, err := appFS.OpenFile(w.localStoragePath, os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + logrus.Warnf("could not open backup file for trust domain %s at %s: %v", w.domain, w.localStoragePath, err) + } else { + defer file.Close() + _, err = file.WriteString(builder.String()) + if err != nil { + logrus.Warnf("could not write to backup file for trust domain %s at %s: %v", w.domain, w.localStoragePath, err) + } + } + } - // TODO: - // 1. Pull out certs from passed in svids and update w.source.domainCertificates - // 2. Update w.source.spireEndpoints[w.url].loadState = Loaded - // 3. write to w.source.updateChan - // - // if w.localStoragePath != "" - // 4. write certs to storage path from w.source.localBackupPath/[domain without spiffe:// part].pem - - //certs := svids.Default().TrustBundle - //w.source.domainCertificates[w.uri] = certs - //if w.source.certLocation != "" { - // certPath := w.source.fileBacking.domainPaths[w.uri][0] - // err := w.source.fileBacking.updateCertificates(certs, w.uri, certPath) - // if err != nil { - // logrus.Warnf("error writing to cert file: %v", err) - // } - //} - //select { - //case w.source.updateChan <- struct{}{}: - //case <-time.After(w.source.updateTimeout): - //} + select { + case w.source.updateChan <- struct{}{}: + case <-time.After(w.source.updateTimeout): + } } func (w *workloadWatcher) OnError(err error) { @@ -150,18 +169,24 @@ func (w *workloadWatcher) OnError(err error) { } if fileTrustSource, err := NewFileTrustSource(domainPaths); err != nil { - // TODO: log error + logrus.Warnf("could not load certs for domain %s from disk: %v", w.domain, err) w.source.spireEndpoints[w.domain].loadState = Failed } else { - // TODO: - // 1. Pull out certs from fileTrustSource and update w.source.domainCertificates - // 2. Update w.source.spireEndpoints[w.url].loadState = Loaded + w.source.domainCertificates[w.domain] = fileTrustSource.TrustedCertificates()[w.domain] + w.source.spireEndpoints[w.domain].loadState = LoadedFromBackup + logrus.Infof("loaded certs for domain %s from disk", w.domain) } } else { - // TODO: log error w.source.spireEndpoints[w.domain].loadState = Failed + logrus.Warnf("could not connect to spire server for domain %s and local storage disabled", w.domain) } + } else { + // if the state was already Loaded, LoadedFromBackup, or Failed then don't do anything + w.source.spireEndpoints[w.domain].loadState = Failed } - // if the state was already Loaded, LoadedFromBackup, or Failed then don't do anything + select { + case w.source.updateChan <- struct{}{}: + case <-time.After(w.source.updateTimeout): + } } diff --git a/internal/common/spiretrustsource_test.go b/internal/common/spiretrustsource_test.go index a9e9638..84b3a2f 100644 --- a/internal/common/spiretrustsource_test.go +++ b/internal/common/spiretrustsource_test.go @@ -44,6 +44,8 @@ func TestInitalLoad(t *testing.T) { require.NoError(t, err) defer source.Stop() + source.waitForUpdate(t) + certs := source.TrustedCertificates()["spiffe://example.org"] require.Len(t, certs, 1) assert.Equal(t, "US", certs[0].Subject.Country[0]) @@ -90,6 +92,7 @@ func TestWriteCerts(t *testing.T) { newSource, err := NewSpireTrustSource(map[string]string{ "spiffe://example.org": dummyWorkloadAPI.Addr(), }, "certs/") + newSource.waitForUpdate(t) assert.Equal(t, ca.Roots(), newSource.TrustedCertificates()["spiffe://example.org"]) }