Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement mTLS resources and configuration for Target Allocator server #284

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ FROM golang:1.22 as builder

# set goproxy=direct
ENV GOPROXY direct
ENV GOINSECURE go.opencensus.io
Copy link
Contributor Author

@musa-asad musa-asad Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily needs to be added since their certificate expired, which is breaking our workflow.

Suggested change
ENV GOINSECURE go.opencensus.io


WORKDIR /workspace
# Copy the Go Modules manifests
Expand Down
113 changes: 113 additions & 0 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/certwatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package config

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
)

type CertAndCAWatcher struct {
certWatcher *certwatcher.CertWatcher

caFilePath string
caPool *x509.CertPool
caWatcher *fsnotify.Watcher

mu sync.RWMutex
}

func NewCertAndCAWatcher(certPath, keyPath, caPath string) (*CertAndCAWatcher, error) {
certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("error creating cert watcher: %w", err)
}

caPool, err := loadCAPool(caPath)
if err != nil {
return nil, fmt.Errorf("error loading CA pool: %w", err)
}

caWatcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("error creating CA file watcher: %w", err)
}
if err := caWatcher.Add(caPath); err != nil {
return nil, fmt.Errorf("error adding CA file to watcher: %w", err)
}

return &CertAndCAWatcher{
certWatcher: certWatcher,
caFilePath: caPath,
caPool: caPool,
caWatcher: caWatcher,
}, nil
}

func loadCAPool(caPath string) (*x509.CertPool, error) {
caCert, err := os.ReadFile(caPath)
caCertPool := x509.NewCertPool()
if err != nil {
return nil, fmt.Errorf("error reading CA file: %w", err)
}
caCertPool.AppendCertsFromPEM(caCert)
return caCertPool, nil
}

func (w *CertAndCAWatcher) Start(ctx context.Context) error {
go func() {
_ = w.certWatcher.Start(ctx)
}()

go w.watchCA(ctx)

<-ctx.Done()
return nil
}

func (w *CertAndCAWatcher) watchCA(ctx context.Context) {
for {
select {
case event, ok := <-w.caWatcher.Events:
if !ok {
return
}
if event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Remove) {
newPool, err := loadCAPool(w.caFilePath)
if err != nil {
continue
}
w.mu.Lock()
w.caPool = newPool
w.mu.Unlock()

// needed incase file removed
if event.Op.Has(fsnotify.Remove) {
time.Sleep(100 * time.Millisecond)
_ = w.caWatcher.Add(w.caFilePath)
}
}
case <-ctx.Done():
return
}
}
}

func (w *CertAndCAWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return w.certWatcher.GetCertificate(clientHello)
}

func (w *CertAndCAWatcher) GetCAPool() *x509.CertPool {
w.mu.RLock()
defer w.mu.RUnlock()
return w.caPool
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package config

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
)

func generateSelfSignedCertAndKey(commonName string) (certPEM, keyPEM []byte, err error) {
// Generate RSA key
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}

// Create a minimal self-signed certificate template
serial, err := rand.Int(rand.Reader, big.NewInt(1<<63-1))
if err != nil {
return nil, nil, err
}

template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),

KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
IsCA: true,
BasicConstraintsValid: true,
}

// Self-sign the certificate
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, err
}

// Encode cert + key to PEM
var certBuf, keyBuf bytes.Buffer
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der})
if err != nil {
return nil, nil, err
}
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
if err != nil {
return nil, nil, err
}

return certBuf.Bytes(), keyBuf.Bytes(), nil
}

func TestCertAndCAWatcher_UpdatesCA(t *testing.T) {
t.Parallel()

// Generate a server cert/key for certwatcher
certPEM, keyPEM, err := generateSelfSignedCertAndKey("test-server")
if err != nil {
t.Fatalf("failed to generate server cert/key: %v", err)
}

// Generate two distinct self-signed certs to represent old CA vs new CA
oldCAPEM, _, err := generateSelfSignedCertAndKey("old-ca")
if err != nil {
t.Fatalf("failed to generate old CA: %v", err)
}
newCAPEM, _, err := generateSelfSignedCertAndKey("new-ca")
if err != nil {
t.Fatalf("failed to generate new CA: %v", err)
}

// Write all these PEM files into a temp dir
tmpDir := t.TempDir()

certPath := filepath.Join(tmpDir, "tls.crt")
keyPath := filepath.Join(tmpDir, "tls.key")
caPath := filepath.Join(tmpDir, "ca.crt")

if err := os.WriteFile(certPath, certPEM, 0600); err != nil {
t.Fatalf("failed to write cert file: %v", err)
}
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
t.Fatalf("failed to write key file: %v", err)
}
if err := os.WriteFile(caPath, oldCAPEM, 0600); err != nil {
t.Fatalf("failed to write initial CA file: %v", err)
}

// Create the CertAndCAWatcher using our files
watcher, err := NewCertAndCAWatcher(certPath, keyPath, caPath)
if err != nil {
t.Fatalf("failed to create CertAndCAWatcher: %v", err)
}

// Start the watcher in the background
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = watcher.Start(ctx)
}()

// Record the initial CA pool pointer
oldPool := watcher.GetCAPool()
if oldPool == nil {
t.Fatal("expected non-nil initial CA pool")
}

// Overwrite the CA file with newCAPEM, triggering a reload
if err := os.WriteFile(caPath, newCAPEM, 0600); err != nil {
t.Fatalf("failed to write new CA file: %v", err)
}

// Loop until the watcher updates the CA pool (or times out)
deadline := time.Now().Add(2 * time.Second)
for {
newPool := watcher.GetCAPool()
if newPool != oldPool {
t.Log("CA pool successfully updated.")
return
}
if time.Now().After(deadline) {
t.Fatal("timed out waiting for CA pool to be updated")
}
time.Sleep(100 * time.Millisecond)
}
}
60 changes: 24 additions & 36 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package config
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
Expand All @@ -24,23 +23,22 @@ import (
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

tamanifest "github.com/aws/amazon-cloudwatch-agent-operator/internal/manifests/targetallocator"
)

const (
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultFilterStrategy = "relabel-config"
DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = ""
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultClientCertMountPath = tamanifest.ClientCertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = DefaultClientCertMountPath + "/tls-ca.crt"
)

type Config struct {
Expand All @@ -66,7 +64,6 @@ type PrometheusCRConfig struct {
}

type HTTPSServerConfig struct {
Enabled bool `yaml:"enabled,omitempty"`
ListenAddr string `yaml:"listen_addr,omitempty"`
CAFilePath string `yaml:"ca_file_path,omitempty"`
TLSCertFilePath string `yaml:"tls_cert_file_path,omitempty"`
Expand Down Expand Up @@ -121,11 +118,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
return err
}

target.HTTPS.Enabled, err = getHttpsEnabled(flagSet)
if err != nil {
return err
}

target.HTTPS.ListenAddr, err = getHttpsListenAddr(flagSet)
if err != nil {
return err
Expand All @@ -150,7 +142,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
}

func unmarshal(cfg *Config, configFile string) error {

yamlFile, err := os.ReadFile(configFile)
if err != nil {
return err
Expand All @@ -169,7 +160,6 @@ func CreateDefaultConfig() Config {
},
AllocationStrategy: &allocation_strategy,
HTTPS: HTTPSServerConfig{
Enabled: true,
ListenAddr: DefaultListenAddr,
CAFilePath: DefaultCABundlePath,
TLSCertFilePath: DefaultTLSCertPath,
Expand Down Expand Up @@ -217,31 +207,29 @@ func ValidateConfig(config *Config) error {
}

func (c HTTPSServerConfig) NewTLSConfig(ctx context.Context) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
}

certWatcher, err := certwatcher.New(c.TLSCertFilePath, c.TLSKeyFilePath)
certWatcher, err := NewCertAndCAWatcher(c.TLSCertFilePath, c.TLSKeyFilePath, c.CAFilePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating certwatcher: %w", err)
}
tlsConfig.GetCertificate = certWatcher.GetCertificate

go func() {
_ = certWatcher.Start(ctx)
}()

if c.CAFilePath == "" {
return tlsConfig, nil
// Create the TLS config
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
GetCertificate: certWatcher.GetCertificate,
ClientCAs: certWatcher.GetCAPool(),
ClientAuth: tls.RequireAndVerifyClientCert,
}

caCert, err := os.ReadFile(c.CAFilePath)
caCertPool := x509.NewCertPool()
if err != nil {
return nil, err
// Dynamically update the CA pool if needed
tlsConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
newTLSConfig := tlsConfig.Clone()
newTLSConfig.ClientCAs = certWatcher.GetCAPool()
return newTLSConfig, nil
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

return tlsConfig, nil
}
Loading
Loading