Skip to content

Latest commit

 

History

History
221 lines (183 loc) · 5.39 KB

saved_notes.md

File metadata and controls

221 lines (183 loc) · 5.39 KB

Saved Useful Notes

SVID validation against SPIRE

Saved for posterity.

package main

import (
	"context"
	"crypto/x509"
	"encoding/pem"
	"flag"
	"github.com/hashicorp/vault/sdk/framework"
	"github.com/hashicorp/vault/sdk/logical"
	//"github.com/spiffe/go-spiffe/spiffe"

	//"github.com/spiffe/go-spiffe/workload"
	"strings"

	// "context"
	// "crypto/subtle"
	// "errors"
	"log"
	"os"
	// "time"

	"github.com/hashicorp/vault/api"
	// "github.com/hashicorp/vault/sdk/framework"
	// "github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/sdk/plugin"
)

// main is the plugin binary's entry point; all of the work is delegated to
// standardVaultPluginInit.
func main() {
	standardVaultPluginInit()
}

// pluginSettingsFilePath holds the value of the -settings-file command-line
// flag. It is filled in by standardVaultPluginInit and logged by Backend.
var pluginSettingsFilePath string

// standardVaultPluginInit parses the plugin's command-line flags and then
// serves the auth backend to Vault, blocking until the plugin exits.
//
// The -settings-file flag is registered on the same flag set as the standard
// Vault plugin API flags so that os.Args is parsed exactly once. Parsing the
// arguments with two disjoint flag sets makes each one stop at the first flag
// it does not know, silently dropping whatever follows it.
//
// A flag parsing error or a plugin.Serve error terminates the process via
// log.Fatal; -h/-help prints usage and exits with status 0.
func standardVaultPluginInit() {
	apiClientMeta := &api.PluginAPIClientMeta{}
	pluginFlags := apiClientMeta.FlagSet()
	pluginFlags.StringVar(&pluginSettingsFilePath, "settings-file", "", "Path to plugin settings")

	if err := pluginFlags.Parse(os.Args[1:]); err != nil {
		if err == flag.ErrHelp {
			// Usage has already been printed by the flag package.
			os.Exit(0)
		}
		log.Fatal(err)
	}

	tlsConfig := apiClientMeta.GetTLSConfig()
	tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)

	if err := plugin.Serve(&plugin.ServeOpts{
		BackendFactoryFunc: BackendFactory,
		TLSProviderFunc:    tlsProviderFunc,
	}); err != nil {
		log.Fatal(err)
	}
}

// BackendFactory is the factory function handed to plugin.Serve. It builds
// the backend with Backend, runs its Setup against the supplied configuration
// and returns it, or returns the Setup error.
func BackendFactory(ctx context.Context, c *logical.BackendConfig) (logical.Backend, error) {
	created := Backend(c)

	setupErr := created.Setup(ctx, c)
	if setupErr != nil {
		return nil, setupErr
	}

	return created, nil
}

// backend is the vault-auth-spire credential backend. It embeds the Vault SDK
// framework.Backend and carries the logger used by its request handlers.
type backend struct {
	*framework.Backend
	//svidWatcher *testWatcher
	//svidClient *workload.X509SVIDClient

	// logger receives the plugin's diagnostic output; it is created in Backend.
	logger *log.Logger
}

// Backend constructs the vault-auth-spire credential backend.
//
// It opens (creating if needed) /tmp/vault-auth-spire.log in append mode as
// the destination for the backend's logger, and registers a single
// unauthenticated "login" path whose update operation is handled by
// pathAuthLogin. Token renewal is routed to pathAuthRenew.
//
// If the log file cannot be opened the error is reported through the standard
// logger and the backend logs to stderr instead, rather than writing to a nil
// file and silently losing every message.
//
// The configuration argument c is currently unused.
func Backend(c *logical.BackendConfig) *backend {
	var b backend

	logFile, err := os.OpenFile("/tmp/vault-auth-spire.log",
		os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Println(err)
		logFile = os.Stderr
	}

	b.logger = log.New(logFile, "vault-auth-spire: ", log.LstdFlags)
	b.logger.Println("Logger has started")

	b.logger.Println("The settings file path is " + pluginSettingsFilePath)

	b.Backend = &framework.Backend{
		BackendType: logical.TypeCredential,
		AuthRenew:   b.pathAuthRenew,
		PathsSpecial: &logical.Paths{
			// login must be reachable without a token.
			Unauthenticated: []string{"login"},
		},
		Paths: []*framework.Path{
			{
				Pattern: "login",
				Fields: map[string]*framework.FieldSchema{
					"svid": {
						Type: framework.TypeString,
					},
				},
				Operations: map[logical.Operation]framework.OperationHandler{
					logical.UpdateOperation: &framework.PathOperation{
						Callback: b.pathAuthLogin,
						Summary:  "Login via Spiffe/Spire SVID",
					},
				},
			},
		},
	}

	//b.svidWatcher = newTestWatcher()
	//b.svidClient, _ = workload.NewX509SVIDClient(b.svidWatcher, workload.WithAddr("unix:///tmp/agent.sock"))
	//b.svidClient.Start()
	//
	//b.logger.Println("Plugin has been configured and svidClient started")

	return &b
}

func (b *backend) pathAuthLogin(_ context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	svid := d.Get("svid").(string)

	if len(svid) <= 0 {
		return nil, logical.ErrPermissionDenied
	}

	b.logger.Println("I got svid " + svid)

	block, _ := pem.Decode([]byte(svid))
	if block == nil {
		b.logger.Println("failed to parse certificate PEM")
	}
	svidCert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		b.logger.Println("failed to parse certificate: " + err.Error())
	}
	b.logger.Println("I created a valid x509.Certificate out of the pem")

	//var trustCertPool = x509.NewCertPool()
	//for _,cert := range b.svidWatcher.TrustBundle{
	//	trustCertPool.AddCert(cert)
	//}
	//b.logger.Println("I created certPool")

	//var certPoolMap = make(map[string]*x509.CertPool)
	//certPoolMap["spiffe://dev.bloomberg.com"] = trustCertPool
	//b.logger.Println("I created certPoolMap")

	//_, err = spiffe.VerifyPeerCertificate([]*x509.Certificate{svidCert}, certPoolMap, spiffe.ExpectAnyPeer())
	//b.logger.Println("I called VerifyPeerCertificate")

	uris := []string{}
	for _, uri := range svidCert.URIs{
		b.logger.Println("Found URI: " + uri.String())
		uris = append(uris, uri.String())
	}

	var result string
	ifnil != err{
		b.logger.Println("There was an error: " + err.Error())
		result = "There was an error: " + err.Error()
	} else{
		b.logger.Println("The cert was verified")
		result = "We've been verified and I found URIs: " + strings.Join(uris, ",")
	}

	// Compose the response
	return &logical.Response{
		Auth: &logical.Auth{
			InternalData: map[string]interface{}{
				"receivedSvid": svid,
			},
			Policies: []string{
				//"Trust Bundles: " + strconv.Itoa(len(b.svidWatcher.TrustBundle)),
				"Result: " + result,
			},
			Metadata: map[string]string{
				"spiffeId": uris[0],
			},
			LeaseOptions: logical.LeaseOptions{
				Renewable: false,
			},
		},
	}, nil
}

// pathAuthRenew is registered as the backend's AuthRenew handler. It is an
// unimplemented stub that always returns a nil response and a nil error.
func (b *backend) pathAuthRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	// TODO
	return nil, nil
}

//type testWatcher struct {
//	TrustBundle  []*x509.Certificate
//	Errors       []error
//}
//
//func newTestWatcher() *testWatcher {
//	return &testWatcher{
//		//updateSignal: make(chan struct{}, 100),
//		//timeout:      10 * time.Second,
//	}
//}
//
//func (w *testWatcher) UpdateX509SVIDs(u *workload.X509SVIDs) {
//	if len(u.SVIDs) > 0 {
//		w.TrustBundle = u.SVIDs[0].TrustBundle
//	}
//}
//
//func (w *testWatcher) OnError(err error) {
//	w.Errors = append(w.Errors, err)
//}