Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spire trustsource #14

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
35 changes: 23 additions & 12 deletions internal/common/spiretrustsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/pem"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"time"
Expand Down Expand Up @@ -42,10 +43,7 @@ type SpireTrustSource struct {
updateTimeout time.Duration

domainCertificates map[string][]*x509.Certificate
}

type certMap struct {
Certs map[string][]string `json:"certs"`
testing bool
}

type workloadWatcher struct {
Expand Down Expand Up @@ -81,7 +79,7 @@ func NewSpireTrustSource(spireEndpointURLs map[string]string, localBackupDir str
if strings.Contains(domain, "/") {
return nil, fmt.Errorf("expected domain without slash but got %s", domain)
}
localStoragePath = localBackupDir + "/" + domain + ".pem"
localStoragePath = filepath.Join(localBackupDir, domain+".pem")
}

client, err := workload.NewX509SVIDClient(
Expand Down Expand Up @@ -113,6 +111,16 @@ func NewSpireTrustSource(spireEndpointURLs map[string]string, localBackupDir str
return source, nil
}

// NewSpireTestSource creates a new spire trust source with the test flag on.
func NewSpireTestSource(spireEndpointURLs map[string]string, localBackupDir string) (*SpireTrustSource, error) {
ts, err := NewSpireTrustSource(spireEndpointURLs, localBackupDir)
if err != nil {
return nil, err
}
ts.testing = true
return ts, nil
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this. I'm going to give some thought to a better way to accomplish what's needed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I would love to use method overriding to add an optional testing flag, but Go doesn't have overriding.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another common pattern is to pass in a config struct that can take on default values instead of using individual parameters.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also pass an optional configuration function ala https://godoc.org/github.com/pkg/term#CBreakMode

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dennisgove seemed to have the main problem of "we don't want to have test flags in prod code", so I don't think extra configuration functions would fix that problem. I'm really not sure how to fix that problem other than making a whole new type of watcher for testing purposes, but I feel like that defeats the purpose of the tests.

// Stop stops all spire clients
func (s *SpireTrustSource) Stop() error {
errs := make([]string, 0)
Expand Down Expand Up @@ -155,9 +163,11 @@ func (w *workloadWatcher) UpdateX509SVIDs(svids *workload.X509SVIDs) {
}
}

select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
if w.source.testing {
select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
}
}
}

Expand All @@ -182,11 +192,12 @@ func (w *workloadWatcher) OnError(err error) {
}
} else {
// if the state was already Loaded, LoadedFromBackup, or Failed then don't do anything
w.source.spireEndpoints[w.domain].loadState = Failed
}

select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
if w.source.testing {
select {
case w.source.updateChan <- struct{}{}:
case <-time.After(w.source.updateTimeout):
}
}
}
18 changes: 9 additions & 9 deletions internal/common/spiretrustsource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ func setX509SVIDResponse(api *spiffetest.WorkloadAPI, ca *spiffetest.CA, svid []
api.SetX509SVIDResponse(response)
}

func TestInitalLoad(t *testing.T) {
func TestInitialLoad(t *testing.T) {
appFS = afero.NewMemMapFs()

afero.WriteFile(appFS, "certs/example.org.pem", []byte(leafCert), 600)

workloadAPI := spiffetest.NewWorkloadAPI(t, nil)
defer workloadAPI.Stop()

source, err := NewSpireTrustSource(map[string]string{
source, err := NewSpireTestSource(map[string]string{
"spiffe://example.org": workloadAPI.Addr(),
}, "certs/")
require.NoError(t, err)
Expand All @@ -54,14 +54,14 @@ func TestInitalLoad(t *testing.T) {
}

func TestInvalidURI(t *testing.T) {
_, err := NewSpireTrustSource(map[string]string{
_, err := NewSpireTestSource(map[string]string{
"spirffe://example.org": "",
}, "certs/")
require.Error(t, err)
}

func TestInvalidDomain(t *testing.T) {
_, err := NewSpireTrustSource(map[string]string{
_, err := NewSpireTestSource(map[string]string{
"spiffe://example.org/test": "",
}, "certs/")
require.Error(t, err)
Expand All @@ -78,7 +78,7 @@ func TestWriteCerts(t *testing.T) {

setX509SVIDResponse(workloadAPI, ca, svidFoo, keyFoo)

source, err := NewSpireTrustSource(map[string]string{
source, err := NewSpireTestSource(map[string]string{
"spiffe://example.org": workloadAPI.Addr(),
}, "certs/")
require.NoError(t, err)
Expand All @@ -89,7 +89,7 @@ func TestWriteCerts(t *testing.T) {
dummyWorkloadAPI := spiffetest.NewWorkloadAPI(t, nil)
defer dummyWorkloadAPI.Stop()

newSource, err := NewSpireTrustSource(map[string]string{
newSource, err := NewSpireTestSource(map[string]string{
"spiffe://example.org": dummyWorkloadAPI.Addr(),
}, "certs/")
newSource.waitForUpdate(t)
Expand All @@ -109,7 +109,7 @@ func TestSpireOverwrite(t *testing.T) {

setX509SVIDResponse(workloadAPI, ca, svidFoo, keyFoo)

source, err := NewSpireTrustSource(map[string]string{
source, err := NewSpireTestSource(map[string]string{
"spiffe://example.org": workloadAPI.Addr(),
}, "certs/")
require.NoError(t, err)
Expand All @@ -119,7 +119,7 @@ func TestSpireOverwrite(t *testing.T) {
assert.Equal(t, ca.Roots(), source.TrustedCertificates()["spiffe://example.org"])
}

func TestSpireReload(t *testing.T) {
func TestSpireRotation(t *testing.T) {
appFS = afero.NewMemMapFs()

workloadAPI := spiffetest.NewWorkloadAPI(t, nil)
Expand All @@ -129,7 +129,7 @@ func TestSpireReload(t *testing.T) {
svidFoo, keyFoo := ca.CreateX509SVID("spiffe://example.org/foo")
setX509SVIDResponse(workloadAPI, ca, svidFoo, keyFoo)

source, err := NewSpireTrustSource(map[string]string{
source, err := NewSpireTestSource(map[string]string{
"spiffe://example.org": workloadAPI.Addr(),
}, "")
require.NoError(t, err)
Expand Down