Skip to content

Commit

Permalink
Move to new spire trust source design with state tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
Peyton Walters committed Jan 8, 2020
1 parent 7ea5c19 commit 7d546c6
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 78 deletions.
2 changes: 1 addition & 1 deletion cmd/plugin/vault-auth-spire.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func BackendFactory(ctx context.Context, backendConfig *logical.BackendConfig) (
spirePlugin.verifier.AddTrustSource(trustSource)
}
if settings.SourceOfTrust.Spire != nil {
trustSource, err := common.NewSpireTrustSource(settings.SourceOfTrust.Spire.SpireEndpointUrls, settings.SourceOfTrust.Spire.LocalBackupPath)
trustSource, err := common.NewSpireTrustSource(settings.SourceOfTrust.Spire.SpireEndpointURLs, settings.SourceOfTrust.Spire.LocalBackupPath)
if err != nil {
return nil, errors.New("vault-auth-spire: Failed to initialize spire TrustSource - " + err.Error())
}
Expand Down
31 changes: 0 additions & 31 deletions internal/common/filetrustsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@ package common

import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"os"
"strings"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -61,34 +58,6 @@ func (source *FileTrustSource) TrustedCertificates() map[string][]*x509.Certific
return source.domainCertificates
}

func (source *FileTrustSource) updateCertificates(certs []*x509.Certificate, spiffeID, path string) error {
builder := strings.Builder{}
for _, cert := range certs {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
builder.Write(pem.EncodeToMemory(block))
}

file, err := appFS.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
defer file.Close()
_, err = file.WriteString(builder.String())
if err != nil {
return err
}

err = source.loadDomain(spiffeID)
if err != nil {
return err
}

return nil
}

// For each domain/file mapping found in source.domainPaths, load the PEM and read all
// certificates from the file.
func (source *FileTrustSource) loadCertificates() error {
Expand Down
10 changes: 5 additions & 5 deletions internal/common/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type FileTrustSourceSettings struct {
}

type SpireTrustSourceSettings struct {
SpireEndpointUrls map[string]string
SpireEndpointURLs map[string]string
LocalBackupPath string
}

Expand Down Expand Up @@ -143,14 +143,14 @@ func readSpireSourceOfTrustSettings() (*SpireTrustSourceSettings, error) {
return nil, errors.New("trustsource.spire.domains is required but not found")
}

viper.SetDefault("trustsource.spire.certLocation", "/var/run/spire/certs/")
viper.SetDefault("trustsource.spire.backupPath", "/var/run/spire/certs/")
viper.SetDefault("trustsource.spire.storeEnabled", true)
spireSettings := &SpireTrustSourceSettings{
SpireEndpoints: viper.GetStringMapString("trustsource.spire.domains"),
CertStorePath: viper.GetString("trustsource.spire.certLocation"),
SpireEndpointURLs: viper.GetStringMapString("trustsource.spire.domains"),
LocalBackupPath: viper.GetString("trustsource.spire.backupPath"),
}
if !viper.GetBool("trustsource.spire.storeEnabled") {
spireSettings.CertStorePath = ""
spireSettings.LocalBackupPath = ""
}

return spireSettings, nil
Expand Down
107 changes: 66 additions & 41 deletions internal/common/spiretrustsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package common

import (
"crypto/x509"
"errors"
"encoding/pem"
"fmt"
"os"
"regexp"
Expand All @@ -13,6 +13,7 @@ import (
"github.com/spiffe/go-spiffe/workload"
)

// SpireLoadState represents the current state of a Spire connection
type SpireLoadState int

const (
Expand All @@ -22,14 +23,16 @@ const (
Failed
)

// SpireTrustSource holds all necessary information to connect to a spire instance and store its certificates
// SpireEndpoint represents a single trust domain and its associated spire server connection
type SpireEndpoint struct {
domain string
loadState SpireLoadState
spireUrl string
spireURL string

client *workload.X509SVIDClient
}

// SpireTrustSource holds all necessary information to connect to a spire instance and store its certificates
type SpireTrustSource struct {
spireEndpoints map[string]*SpireEndpoint

Expand Down Expand Up @@ -57,7 +60,7 @@ func (s *SpireTrustSource) TrustedCertificates() map[string][]*x509.Certificate
}

// NewSpireTrustSource creates a new trust source with connectivity to one or more spire instances.
func NewSpireTrustSource(spireEndpointUrls map[string]string, localBackupDir string) (*SpireTrustSource, error) {
func NewSpireTrustSource(spireEndpointURLs map[string]string, localBackupDir string) (*SpireTrustSource, error) {
source := &SpireTrustSource{
spireEndpoints: make(map[string]*SpireEndpoint, 0),
localBackupDir: localBackupDir,
Expand All @@ -66,35 +69,44 @@ func NewSpireTrustSource(spireEndpointUrls map[string]string, localBackupDir str
updateTimeout: 5 * time.Second,
}

for domain, spireUrl := range spireEndpointUrls {
re := regexp.MustCompile(`^spiffe://(\S+)$`)
for trustDomain, spireURL := range spireEndpointURLs {
localStoragePath := ""
if localBackupDir != "" {
localStoragePath = "" // TODO: pull out from domain
matches := re.FindStringSubmatch(trustDomain)
if len(matches) < 2 {
return nil, fmt.Errorf("expected domain of form spiffe://<trust_domain> but got %s", trustDomain)
}
domain := matches[1]
if strings.Contains(domain, "/") {
return nil, fmt.Errorf("expected domain without slash but got %s", domain)
}
localStoragePath = localBackupDir + "/" + domain + ".pem"
}

client, err := workload.NewX509SVIDClient(
&workloadWatcher{
domain: domain,
domain: trustDomain,
source: source,
localStoragePath: localStoragePath,
},
workload.WithAddr(spireUrl),
workload.WithAddr(spireURL),
)
if err != nil {
return nil, fmt.Errorf("failed to construct a new NewX509SVIDClient for %s - %v", spireUrl, err)
return nil, fmt.Errorf("failed to construct a new NewX509SVIDClient for %s: %v", spireURL, err)
}

source.spireEndpoints[domain] = &SpireEndpoint{
domain: domain,
source.spireEndpoints[trustDomain] = &SpireEndpoint{
domain: trustDomain,
loadState: Pending,
spireUrl: spireUrl,
spireURL: spireURL,
client: client,
}
}

for _, spireEndpoint := range source.spireEndpoints {
if err := spireEndpoint.client.Start(); err != nil {
return nil, fmt.Errorf("failed to start NewX509SVIDClient for domain %s - %v", spireEndpoint.domain, err)
return nil, fmt.Errorf("failed to start NewX509SVIDClient for domain %s: %v", spireEndpoint.domain, err)
}
}

Expand All @@ -106,7 +118,7 @@ func (s *SpireTrustSource) Stop() error {
errs := make([]string, 0)
for _, spireEndpoint := range s.spireEndpoints {
if err := spireEndpoint.client.Stop(); err != nil {
errs = append(errs, fmt.Sprintf("domain %s - %v", spireEndpoint.domain, err))
errs = append(errs, fmt.Sprintf("domain %s: %v", spireEndpoint.domain, err))
}
}

Expand All @@ -118,28 +130,35 @@ func (s *SpireTrustSource) Stop() error {
}

func (w *workloadWatcher) UpdateX509SVIDs(svids *workload.X509SVIDs) {
certs := svids.Default().TrustBundle
w.source.domainCertificates[w.domain] = certs
w.source.spireEndpoints[w.domain].loadState = Loaded

if w.localStoragePath != "" {
builder := strings.Builder{}
for _, cert := range certs {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
builder.Write(pem.EncodeToMemory(block))
}
file, err := appFS.OpenFile(w.localStoragePath, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
logrus.Warnf("could not open backup file for trust domain %s at %s: %v", w.domain, w.localStoragePath, err)
} else {
defer file.Close()
_, err = file.WriteString(builder.String())
if err != nil {
logrus.Warnf("could not write to backup file for trust domain %s at %s: %v", w.domain, w.localStoragePath, err)
}
}
}

// TODO:
// 1. Pull out certs from passed in svids and update w.source.domainCertificates
// 2. Update w.source.spireEndpoints[w.url].loadState = Loaded
// 3. write to w.source.updateChan
//
// if w.localStoragePath != ""
// 4. write certs to storage path from w.source.localBackupPath/[domain without spiffe:// part].pem

//certs := svids.Default().TrustBundle
//w.source.domainCertificates[w.uri] = certs
//if w.source.certLocation != "" {
// certPath := w.source.fileBacking.domainPaths[w.uri][0]
// err := w.source.fileBacking.updateCertificates(certs, w.uri, certPath)
// if err != nil {
// logrus.Warnf("error writing to cert file: %v", err)
// }
//}
//select {
//case w.source.updateChan <- struct{}{}:
//case <-time.After(w.source.updateTimeout):
//}
select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
}
}

func (w *workloadWatcher) OnError(err error) {
Expand All @@ -150,18 +169,24 @@ func (w *workloadWatcher) OnError(err error) {
}

if fileTrustSource, err := NewFileTrustSource(domainPaths); err != nil {
// TODO: log error
logrus.Warnf("could not load certs for domain %s from disk: %v", w.domain, err)
w.source.spireEndpoints[w.domain].loadState = Failed
} else {
// TODO:
// 1. Pull out certs from fileTrustSource and update w.source.domainCertificates
// 2. Update w.source.spireEndpoints[w.url].loadState = Loaded
w.source.domainCertificates[w.domain] = fileTrustSource.TrustedCertificates()[w.domain]
w.source.spireEndpoints[w.domain].loadState = LoadedFromBackup
logrus.Infof("loaded certs for domain %s from disk", w.domain)
}
} else {
// TODO: log error
w.source.spireEndpoints[w.domain].loadState = Failed
logrus.Warnf("could not connect to spire server for domain %s and local storage disabled", w.domain)
}
} else {
// if the state was already Loaded, LoadedFromBackup, or Failed then don't do anything
w.source.spireEndpoints[w.domain].loadState = Failed
}

// if the state was already Loaded, LoadedFromBackup, or Failed then don't do anything
select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
}
}
3 changes: 3 additions & 0 deletions internal/common/spiretrustsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func TestInitalLoad(t *testing.T) {
require.NoError(t, err)
defer source.Stop()

source.waitForUpdate(t)

certs := source.TrustedCertificates()["spiffe://example.org"]
require.Len(t, certs, 1)
assert.Equal(t, "US", certs[0].Subject.Country[0])
Expand Down Expand Up @@ -90,6 +92,7 @@ func TestWriteCerts(t *testing.T) {
newSource, err := NewSpireTrustSource(map[string]string{
"spiffe://example.org": dummyWorkloadAPI.Addr(),
}, "certs/")
newSource.waitForUpdate(t)
assert.Equal(t, ca.Roots(), newSource.TrustedCertificates()["spiffe://example.org"])
}

Expand Down

0 comments on commit 7d546c6

Please sign in to comment.