diff --git a/pkg/server/catalog/keymanager.go b/pkg/server/catalog/keymanager.go index 645570c34d7..13338bc96d4 100644 --- a/pkg/server/catalog/keymanager.go +++ b/pkg/server/catalog/keymanager.go @@ -2,6 +2,7 @@ package catalog import ( "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/server/plugin/keymanager/hashicorpvault" "github.com/spiffe/spire/pkg/server/plugin/keymanager" "github.com/spiffe/spire/pkg/server/plugin/keymanager/awskms" @@ -33,6 +34,7 @@ func (repo *keyManagerRepository) BuiltIns() []catalog.BuiltIn { disk.BuiltIn(), gcpkms.BuiltIn(), azurekeyvault.BuiltIn(), + hashicorpvault.BuiltIn(), memory.BuiltIn(), } } diff --git a/pkg/server/plugin/keymanager/hashicorpvault/hashicorp_vault.go b/pkg/server/plugin/keymanager/hashicorpvault/hashicorp_vault.go new file mode 100644 index 00000000000..2a207a7cfa4 --- /dev/null +++ b/pkg/server/plugin/keymanager/hashicorpvault/hashicorp_vault.go @@ -0,0 +1,310 @@ +package hashicorpvault + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/pem" + "errors" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/hcl" + keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" + configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" + "github.com/spiffe/spire/pkg/common/catalog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "os" + "sync" +) + +const ( + pluginName = "hashicorp_vault" +) + +func BuiltIn() catalog.BuiltIn { + return builtin(New()) +} + +func builtin(p *Plugin) catalog.BuiltIn { + return catalog.MakeBuiltIn(pluginName, + keymanagerv1.KeyManagerPluginServer(p), + configv1.ConfigServiceServer(p), + ) +} + +type keyEntry struct { + PublicKey *keymanagerv1.PublicKey +} + +type pluginHooks struct { + // Used for testing only. + scheduleDeleteSignal chan error + refreshKeysSignal chan error + disposeKeysSignal chan error + + lookupEnv func(string) (string, bool) +} + +// Config provides configuration context for the plugin. +type Config struct { + // A URL of Vault server. (e.g., https://vault.example.com:8443/) + VaultAddr string `hcl:"vault_addr" json:"vault_addr"` + + // Configuration for the Token authentication method + TokenAuth *TokenAuthConfig `hcl:"token_auth" json:"token_auth,omitempty"` + + // TODO: Support other auth methods + // TODO: Support client certificate and key +} + +type TokenAuthConfig struct { + // Token string to set into "X-Vault-Token" header + Token string `hcl:"token" json:"token"` +} + +// Plugin is the main representation of this keymanager plugin +type Plugin struct { + keymanagerv1.UnsafeKeyManagerServer + configv1.UnsafeConfigServer + + logger hclog.Logger + mu sync.RWMutex + entries map[string]keyEntry + + authMethod AuthMethod + cc *ClientConfig + vc *Client + + hooks pluginHooks +} + +// New returns an instantiated plugin. +func New() *Plugin { + return newPlugin() +} + +// newPlugin returns a new plugin instance. +func newPlugin() *Plugin { + return &Plugin{ + entries: make(map[string]keyEntry), + hooks: pluginHooks{ + lookupEnv: os.LookupEnv, + }, + } +} + +// SetLogger sets a logger +func (p *Plugin) SetLogger(log hclog.Logger) { + p.logger = log +} + +func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { + config := new(Config) + + if err := hcl.Decode(&config, req.HclConfiguration); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + } + + p.mu.Lock() + defer p.mu.Unlock() + + am, err := parseAuthMethod(config) + if err != nil { + return nil, err + } + cp, err := p.genClientParams(am, config) + if err != nil { + return nil, err + } + vcConfig, err := NewClientConfig(cp, p.logger) + if err != nil { + return nil, err + } + + p.authMethod = am + p.cc = vcConfig + + return &configv1.ConfigureResponse{}, nil +} + +func parseAuthMethod(config *Config) (AuthMethod, error) { + var authMethod AuthMethod + if config.TokenAuth != nil { + authMethod = TOKEN + } + + if authMethod != 0 { + return authMethod, nil + } + + return 0, status.Error(codes.InvalidArgument, "must be configured one of these authentication method 'Token'") +} + +func (p *Plugin) genClientParams(method AuthMethod, config *Config) (*ClientParams, error) { + cp := &ClientParams{ + VaultAddr: p.getEnvOrDefault(envVaultAddr, config.VaultAddr), + } + + switch method { + case TOKEN: + cp.Token = p.getEnvOrDefault(envVaultToken, config.TokenAuth.Token) + } + + return cp, nil +} + +func (p *Plugin) getEnvOrDefault(envKey, fallback string) string { + if value, ok := p.hooks.lookupEnv(envKey); ok { + return value + } + return fallback +} + +func (p *Plugin) GenerateKey(ctx context.Context, req *keymanagerv1.GenerateKeyRequest) (*keymanagerv1.GenerateKeyResponse, error) { + if req.KeyId == "" { + return nil, status.Error(codes.InvalidArgument, "key id is required") + } + if req.KeyType == keymanagerv1.KeyType_UNSPECIFIED_KEY_TYPE { + return nil, status.Error(codes.InvalidArgument, "key type is required") + } + + p.mu.Lock() + defer p.mu.Unlock() + + spireKeyID := req.KeyId + newKeyEntry, err := p.createKey(ctx, spireKeyID, req.KeyType) + if err != nil { + return nil, err + } + + p.entries[spireKeyID] = *newKeyEntry + + return &keymanagerv1.GenerateKeyResponse{ + PublicKey: newKeyEntry.PublicKey, + }, nil +} + +func (p *Plugin) SignData(ctx context.Context, req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDataResponse, error) { + return nil, errors.New("sign data is not implemented") +} + +func (p *Plugin) GetPublicKey(_ context.Context, req *keymanagerv1.GetPublicKeyRequest) (*keymanagerv1.GetPublicKeyResponse, error) { + if req.KeyId == "" { + return nil, status.Error(codes.InvalidArgument, "key id is required") + } + + p.mu.RLock() + defer p.mu.RUnlock() + + entry, ok := p.entries[req.KeyId] + if !ok { + return nil, status.Errorf(codes.NotFound, "key %q not found", req.KeyId) + } + + return &keymanagerv1.GetPublicKeyResponse{ + PublicKey: entry.PublicKey, + }, nil +} + +func (p *Plugin) GetPublicKeys(context.Context, *keymanagerv1.GetPublicKeysRequest) (*keymanagerv1.GetPublicKeysResponse, error) { + var keys = make([]*keymanagerv1.PublicKey, len(p.entries), 0) + + p.mu.RLock() + defer p.mu.RUnlock() + + for _, key := range p.entries { + keys = append(keys, key.PublicKey) + } + + p.logger.Debug("getting public keys") + + return &keymanagerv1.GetPublicKeysResponse{PublicKeys: keys}, nil +} + +func (p *Plugin) createKey(ctx context.Context, spireKeyID string, keyType keymanagerv1.KeyType) (*keyEntry, error) { + err := p.genVaultClient() + if err != nil { + return nil, err + } + + kt, err := convertToTransitKeyType(keyType) + if err != nil { + return nil, err + } + + s, err := p.vc.CreateKey(ctx, spireKeyID, *kt) + if err != nil { + return nil, err + } + + s, err = p.vc.GetKey(ctx, spireKeyID) + if err != nil { + return nil, err + } + + // TODO: Should we support multiple versions of the key? + keys := s.Data["keys"].(map[string]interface{}) + last := keys["1"].(map[string]interface{}) + encodedPub := []byte(last["public_key"].(string)) + + // TODO: Should I handle the rest somehow? + pemBlock, _ := pem.Decode(encodedPub) + if pemBlock == nil || pemBlock.Type != "PUBLIC KEY" { + return nil, status.Error(codes.Internal, "unable to decode PEM key") + } + + return &keyEntry{ + PublicKey: &keymanagerv1.PublicKey{ + Id: spireKeyID, + Type: keyType, + PkixData: pemBlock.Bytes, + Fingerprint: makeFingerprint(pemBlock.Bytes), + }, + }, nil +} + +func convertToTransitKeyType(keyType keymanagerv1.KeyType) (*TransitKeyType, error) { + switch keyType { + case keymanagerv1.KeyType_EC_P256: + return to.Ptr(TransitKeyType_ECDSA_P256), nil + case keymanagerv1.KeyType_EC_P384: + return to.Ptr(TransitKeyType_ECDSA_P384), nil + case keymanagerv1.KeyType_RSA_2048: + return to.Ptr(TransitKeyType_RSA_2048), nil + case keymanagerv1.KeyType_RSA_4096: + return to.Ptr(TransitKeyType_RSA_4096), nil + default: + return nil, status.Errorf(codes.Internal, "unsupported key type: %v", keyType) + } +} + +// TODO: Use context here (?) +// TODO: Should we really generate the client like this, relies on the fact that the mutex is already locked :( +func (p *Plugin) genVaultClient() error { + if p.vc == nil { + renewCh := make(chan struct{}) + vc, err := p.cc.NewAuthenticatedClient(p.authMethod, renewCh) + if err != nil { + return status.Errorf(codes.Internal, "failed to prepare authenticated client: %v", err) + } + p.vc = vc + + // if renewCh has been closed, the token can not be renewed and may expire, + // it needs to re-authenticate to the Vault. + go func() { + <-renewCh + p.mu.Lock() + defer p.mu.Unlock() + p.vc = nil + p.logger.Debug("Going to re-authenticate to the Vault during the next key manager operation") + }() + } + + return nil +} + +func makeFingerprint(pkixData []byte) string { + s := sha256.Sum256(pkixData) + return hex.EncodeToString(s[:]) +} diff --git a/pkg/server/plugin/keymanager/hashicorpvault/renewer.go b/pkg/server/plugin/keymanager/hashicorpvault/renewer.go new file mode 100644 index 00000000000..8f14ab5ab34 --- /dev/null +++ b/pkg/server/plugin/keymanager/hashicorpvault/renewer.go @@ -0,0 +1,50 @@ +package hashicorpvault + +import ( + "github.com/hashicorp/go-hclog" + vapi "github.com/hashicorp/vault/api" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + defaultRenewBehavior = vapi.RenewBehaviorIgnoreErrors +) + +type Renew struct { + logger hclog.Logger + watcher *vapi.LifetimeWatcher +} + +func NewRenew(client *vapi.Client, secret *vapi.Secret, logger hclog.Logger) (*Renew, error) { + watcher, err := client.NewLifetimeWatcher(&vapi.LifetimeWatcherInput{ + Secret: secret, + RenewBehavior: defaultRenewBehavior, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to initialize Renewer: %v", err) + } + return &Renew{ + logger: logger, + watcher: watcher, + }, nil +} + +func (r *Renew) Run() { + go r.watcher.Start() + defer r.watcher.Stop() + + for { + select { + case err := <-r.watcher.DoneCh(): + if err != nil { + r.logger.Error("Failed to renew auth token", "err", err) + return + } + r.logger.Error("Failed to renew auth token. Retries may have exceeded the lease time threshold") + return + case renewal := <-r.watcher.RenewCh(): + r.logger.Debug("Successfully renew auth token", "request_id", renewal.Secret.RequestID, "lease_duration", renewal.Secret.Auth.LeaseDuration) + } + } +} diff --git a/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go b/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go new file mode 100644 index 00000000000..951665ad769 --- /dev/null +++ b/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go @@ -0,0 +1,373 @@ +package hashicorpvault + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "github.com/hashicorp/go-hclog" + vapi "github.com/hashicorp/vault/api" + "github.com/imdario/mergo" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "net/http" + "os" + + "github.com/spiffe/spire/pkg/common/pemutil" +) + +// TODO: Delete everything that is unused in here + +const ( + envVaultAddr = "VAULT_ADDR" + envVaultToken = "VAULT_TOKEN" + envVaultClientCert = "VAULT_CLIENT_CERT" + envVaultClientKey = "VAULT_CLIENT_KEY" + envVaultCACert = "VAULT_CACERT" + envVaultAppRoleID = "VAULT_APPROLE_ID" + envVaultAppRoleSecretID = "VAULT_APPROLE_SECRET_ID" // #nosec G101 + envVaultNamespace = "VAULT_NAMESPACE" + + defaultCertMountPoint = "cert" + defaultPKIMountPoint = "pki" + defaultAppRoleMountPoint = "approle" + defaultK8sMountPoint = "kubernetes" +) + +type AuthMethod int + +const ( + _ AuthMethod = iota + CERT + TOKEN + APPROLE + K8S +) + +// ClientConfig represents configuration parameters for vault client +type ClientConfig struct { + Logger hclog.Logger + // vault client parameters + clientParams *ClientParams +} + +type ClientParams struct { + // A URL of Vault server. (e.g., https://vault.example.com:8443/) + VaultAddr string + // Name of mount point where PKI secret engine is mounted. (e.e., //ca/pem ) + PKIMountPoint string + // token string to use when auth method is 'token' + Token string + // Name of mount point where TLS Cert auth method is mounted. (e.g., /auth//login ) + CertAuthMountPoint string + // Name of the Vault role. + // If given, the plugin authenticates against only the named role + CertAuthRoleName string + // Path to a client certificate file to be used when auth method is 'cert' + ClientCertPath string + // Path to a client private key file to be used when auth method is 'cert' + ClientKeyPath string + // Path to a CA certificate file to be used when client verifies a server certificate + CACertPath string + // Name of mount point where AppRole auth method is mounted. (e.g., /auth//login ) + AppRoleAuthMountPoint string + // An identifier of AppRole + AppRoleID string + // A credential set of AppRole + AppRoleSecretID string + // Name of the mount point where Kubernetes auth method is mounted. (e.g., /auth//login) + K8sAuthMountPoint string + // Name of the Vault role. + // The plugin authenticates against the named role. + K8sAuthRoleName string + // Path to a K8s Service Account Token to be used when auth method is 'k8s' + K8sAuthTokenPath string + // If true, client accepts any certificates. + // It should be used only test environment so on. + TLSSKipVerify bool + // MaxRetries controls the number of times to retry to connect + // Set to 0 to disable retrying. + // If the value is nil, to use the default in hashicorp/vault/api. + MaxRetries *int + // Name of the Vault namespace + Namespace string +} + +type Client struct { + vaultClient *vapi.Client + clientParams *ClientParams +} + +// SignCSRResponse includes certificates which are generates by Vault +type SignCSRResponse struct { + // A certificate requested to sign + CACertPEM string + // A certificate of CA(Vault) + UpstreamCACertPEM string + // Set of Upstream CA certificates + UpstreamCACertChainPEM []string +} + +// NewClientConfig returns a new *ClientConfig with default parameters. +func NewClientConfig(cp *ClientParams, logger hclog.Logger) (*ClientConfig, error) { + cc := &ClientConfig{ + Logger: logger, + } + defaultParams := &ClientParams{ + CertAuthMountPoint: defaultCertMountPoint, + AppRoleAuthMountPoint: defaultAppRoleMountPoint, + K8sAuthMountPoint: defaultK8sMountPoint, + PKIMountPoint: defaultPKIMountPoint, + } + if err := mergo.Merge(cp, defaultParams); err != nil { + return nil, status.Errorf(codes.Internal, "unable to merge client params: %v", err) + } + cc.clientParams = cp + return cc, nil +} + +// NewAuthenticatedClient returns a new authenticated vault client with given authentication method +func (c *ClientConfig) NewAuthenticatedClient(method AuthMethod, renewCh chan struct{}) (client *Client, err error) { + config := vapi.DefaultConfig() + config.Address = c.clientParams.VaultAddr + if c.clientParams.MaxRetries != nil { + config.MaxRetries = *c.clientParams.MaxRetries + } + + if err := c.configureTLS(config); err != nil { + return nil, err + } + vc, err := vapi.NewClient(config) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create Vault client: %v", err) + } + + if c.clientParams.Namespace != "" { + vc.SetNamespace(c.clientParams.Namespace) + } + + client = &Client{ + vaultClient: vc, + clientParams: c.clientParams, + } + + var sec *vapi.Secret + switch method { + case TOKEN: + sec, err = client.LookupSelf(c.clientParams.Token) + if err != nil { + return nil, err + } + if sec == nil { + return nil, status.Error(codes.Internal, "lookup self response is nil") + } + case CERT: + path := fmt.Sprintf("auth/%v/login", c.clientParams.CertAuthMountPoint) + sec, err = client.Auth(path, map[string]any{ + "name": c.clientParams.CertAuthRoleName, + }) + if err != nil { + return nil, err + } + if sec == nil { + return nil, status.Error(codes.Internal, "tls cert authentication response is nil") + } + case APPROLE: + path := fmt.Sprintf("auth/%v/login", c.clientParams.AppRoleAuthMountPoint) + body := map[string]any{ + "role_id": c.clientParams.AppRoleID, + "secret_id": c.clientParams.AppRoleSecretID, + } + sec, err = client.Auth(path, body) + if err != nil { + return nil, err + } + if sec == nil { + return nil, status.Error(codes.Internal, "approle authentication response is nil") + } + case K8S: + b, err := os.ReadFile(c.clientParams.K8sAuthTokenPath) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to read k8s service account token: %v", err) + } + path := fmt.Sprintf("auth/%s/login", c.clientParams.K8sAuthMountPoint) + body := map[string]any{ + "role": c.clientParams.K8sAuthRoleName, + "jwt": string(b), + } + sec, err = client.Auth(path, body) + if err != nil { + return nil, err + } + if sec == nil { + return nil, status.Error(codes.Internal, "k8s authentication response is nil") + } + } + + err = handleRenewToken(vc, sec, renewCh, c.Logger) + if err != nil { + return nil, err + } + + return client, nil +} + +// handleRenewToken handles renewing the vault token. +// if the token is non-renewable or renew failed, renewCh will be closed. +func handleRenewToken(vc *vapi.Client, sec *vapi.Secret, renewCh chan struct{}, logger hclog.Logger) error { + if sec == nil || sec.Auth == nil { + return status.Error(codes.InvalidArgument, "secret is nil") + } + + if sec.Auth.LeaseDuration == 0 { + logger.Debug("Token will never expire") + return nil + } + if !sec.Auth.Renewable { + logger.Debug("Token is not renewable") + close(renewCh) + return nil + } + renew, err := NewRenew(vc, sec, logger) + if err != nil { + logger.Error("unable to create renew", err) + return err + } + + go func() { + defer close(renewCh) + renew.Run() + }() + + logger.Debug("Token will be renewed") + + return nil +} + +// ConfigureTLS Configures TLS for Vault Client +func (c *ClientConfig) configureTLS(vc *vapi.Config) error { + if vc.HttpClient == nil { + vc.HttpClient = vapi.DefaultConfig().HttpClient + } + clientTLSConfig := vc.HttpClient.Transport.(*http.Transport).TLSClientConfig + + var clientCert tls.Certificate + foundClientCert := false + + switch { + case c.clientParams.ClientCertPath != "" && c.clientParams.ClientKeyPath != "": + c, err := tls.LoadX509KeyPair(c.clientParams.ClientCertPath, c.clientParams.ClientKeyPath) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to parse client cert and private-key: %v", err) + } + clientCert = c + foundClientCert = true + case c.clientParams.ClientCertPath != "" || c.clientParams.ClientKeyPath != "": + return status.Error(codes.InvalidArgument, "both client cert and client key are required") + } + + if c.clientParams.CACertPath != "" { + certs, err := pemutil.LoadCertificates(c.clientParams.CACertPath) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to load CA certificate: %v", err) + } + pool := x509.NewCertPool() + for _, cert := range certs { + pool.AddCert(cert) + } + clientTLSConfig.RootCAs = pool + } + + if c.clientParams.TLSSKipVerify { + clientTLSConfig.InsecureSkipVerify = true + } + + if foundClientCert { + clientTLSConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return &clientCert, nil + } + } + + return nil +} + +// SetToken wraps vapi.Client.SetToken() +func (c *Client) SetToken(v string) { + c.vaultClient.SetToken(v) +} + +// Auth authenticates to vault server with TLS certificate method +func (c *Client) Auth(path string, body map[string]any) (*vapi.Secret, error) { + c.vaultClient.ClearToken() + secret, err := c.vaultClient.Logical().Write(path, body) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "authentication failed %v: %v", path, err) + } + + tokenID, err := secret.TokenID() + if err != nil { + return nil, status.Errorf(codes.Internal, "authentication is successful, but could not get token: %v", err) + } + c.vaultClient.SetToken(tokenID) + return secret, nil +} + +func (c *Client) LookupSelf(token string) (*vapi.Secret, error) { + if token == "" { + return nil, status.Error(codes.InvalidArgument, "token is empty") + } + c.SetToken(token) + + secret, err := c.vaultClient.Logical().Read("auth/token/lookup-self") + if err != nil { + return nil, status.Errorf(codes.Internal, "token lookup failed: %v", err) + } + + id, err := secret.TokenID() + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get TokenID: %v", err) + } + renewable, err := secret.TokenIsRenewable() + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to determine if token is renewable: %v", err) + } + ttl, err := secret.TokenTTL() + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to get token ttl: %v", err) + } + secret.Auth = &vapi.SecretAuth{ + ClientToken: id, + Renewable: renewable, + LeaseDuration: int(ttl.Seconds()), + // don't care any parameters + } + return secret, nil +} + +type TransitKeyType string + +const ( + TransitKeyType_RSA_2048 TransitKeyType = "rsa-2048" + TransitKeyType_RSA_4096 TransitKeyType = "rsa-4096" + TransitKeyType_ECDSA_P256 TransitKeyType = "ecdsa-p256" + TransitKeyType_ECDSA_P384 TransitKeyType = "ecdsa-p384" +) + +// CreateKey creates a new key in the specified transit secret engine +// See: https://developer.hashicorp.com/vault/api-docs/secret/transit#create-key +func (c *Client) CreateKey(ctx context.Context, spireKeyID string, keyType TransitKeyType) (*vapi.Secret, error) { + arguments := map[string]interface{}{ + "type": keyType, + "exportable": "false", // TODO: Maybe make this configurable + } + + // TODO: Handle errors here such as key already exists + // TODO: Make the transit engine path configurable + return c.vaultClient.Logical().WriteWithContext(ctx, fmt.Sprintf("/transit/keys/%s", spireKeyID), arguments) +} + +func (c *Client) GetKey(ctx context.Context, spireKeyID string) (*vapi.Secret, error) { + // TODO: Handle errors here + // TODO: Make the transit engine path configurable + return c.vaultClient.Logical().ReadWithContext(ctx, fmt.Sprintf("/transit/keys/%s", spireKeyID)) +}