Skip to content

Implement mTLS resources and configuration for Target Allocator server #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/certwatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a generic cert watcher..we can probably move it outside of target allocator so other parts of the code can use it. Maybe /internal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I can move it out.


import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
)

type CertAndCAWatcher struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this same as the fluent-bit one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cert watcher exists on the server-side, so this one is not the same as the fluent-bit one. The fluent-bit one is in the CloudWatch Agent: https://github.com/aws/amazon-cloudwatch-agent/blob/799f0c3bdcfc18fcadf3b5e0c1c9ebaa543e18a8/internal/tls/certWatcher.go#L22.

certWatcher *certwatcher.CertWatcher

caFilePath string
caPool *x509.CertPool
caWatcher *fsnotify.Watcher

mu sync.RWMutex
}

func NewCertAndCAWatcher(certPath, keyPath, caPath string) (*CertAndCAWatcher, error) {
certWatcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("error creating cert watcher: %w", err)
}

caPool, err := loadCAPool(caPath)
if err != nil {
return nil, fmt.Errorf("error loading CA pool: %w", err)
}

caWatcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("error creating CA file watcher: %w", err)
}
if err := caWatcher.Add(caPath); err != nil {
return nil, fmt.Errorf("error adding CA file to watcher: %w", err)
}

return &CertAndCAWatcher{
certWatcher: certWatcher,
caFilePath: caPath,
caPool: caPool,
caWatcher: caWatcher,
}, nil
}

func loadCAPool(caPath string) (*x509.CertPool, error) {
caCert, err := os.ReadFile(caPath)
caCertPool := x509.NewCertPool()
if err != nil {
return nil, fmt.Errorf("error reading CA file: %w", err)
}
caCertPool.AppendCertsFromPEM(caCert)
return caCertPool, nil
}

func (w *CertAndCAWatcher) Start(ctx context.Context) error {
go func() {
_ = w.certWatcher.Start(ctx)
}()

go w.watchCA(ctx)

<-ctx.Done()
return nil
}

func (w *CertAndCAWatcher) watchCA(ctx context.Context) {
for {
select {
case event, ok := <-w.caWatcher.Events:
if !ok {
return
}
if event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Remove) {
newPool, err := loadCAPool(w.caFilePath)
if err != nil {
continue
}
w.mu.Lock()
w.caPool = newPool
w.mu.Unlock()

// needed incase file removed
if event.Op.Has(fsnotify.Remove) {
time.Sleep(100 * time.Millisecond)
_ = w.caWatcher.Add(w.caFilePath)
}
}
case <-ctx.Done():
return
}
}
}

func (w *CertAndCAWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return w.certWatcher.GetCertificate(clientHello)
}

func (w *CertAndCAWatcher) GetCAPool() *x509.CertPool {
w.mu.RLock()
defer w.mu.RUnlock()
return w.caPool
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package config

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
)

func generateSelfSignedCertAndKey(commonName string) (certPEM, keyPEM []byte, err error) {
// Generate RSA key
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}

// Create a minimal self-signed certificate template
serial, err := rand.Int(rand.Reader, big.NewInt(1<<63-1))
if err != nil {
return nil, nil, err
}

template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),

KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
IsCA: true,
BasicConstraintsValid: true,
}

// Self-sign the certificate
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, err
}

// Encode cert + key to PEM
var certBuf, keyBuf bytes.Buffer
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der})
if err != nil {
return nil, nil, err
}
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
if err != nil {
return nil, nil, err
}

return certBuf.Bytes(), keyBuf.Bytes(), nil
}

func TestCertAndCAWatcher_UpdatesCA(t *testing.T) {
t.Parallel()

// Generate a server cert/key for certwatcher
certPEM, keyPEM, err := generateSelfSignedCertAndKey("test-server")
if err != nil {
t.Fatalf("failed to generate server cert/key: %v", err)
}

// Generate two distinct self-signed certs to represent old CA vs new CA
oldCAPEM, _, err := generateSelfSignedCertAndKey("old-ca")
if err != nil {
t.Fatalf("failed to generate old CA: %v", err)
}
newCAPEM, _, err := generateSelfSignedCertAndKey("new-ca")
if err != nil {
t.Fatalf("failed to generate new CA: %v", err)
}

// Write all these PEM files into a temp dir
tmpDir := t.TempDir()

certPath := filepath.Join(tmpDir, "tls.crt")
keyPath := filepath.Join(tmpDir, "tls.key")
caPath := filepath.Join(tmpDir, "ca.crt")

if err := os.WriteFile(certPath, certPEM, 0600); err != nil {
t.Fatalf("failed to write cert file: %v", err)
}
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
t.Fatalf("failed to write key file: %v", err)
}
if err := os.WriteFile(caPath, oldCAPEM, 0600); err != nil {
t.Fatalf("failed to write initial CA file: %v", err)
}

// Create the CertAndCAWatcher using our files
watcher, err := NewCertAndCAWatcher(certPath, keyPath, caPath)
if err != nil {
t.Fatalf("failed to create CertAndCAWatcher: %v", err)
}

// Start the watcher in the background
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = watcher.Start(ctx)
}()

// Record the initial CA pool pointer
oldPool := watcher.GetCAPool()
if oldPool == nil {
t.Fatal("expected non-nil initial CA pool")
}

// Overwrite the CA file with newCAPEM, triggering a reload
if err := os.WriteFile(caPath, newCAPEM, 0600); err != nil {
t.Fatalf("failed to write new CA file: %v", err)
}

// Loop until the watcher updates the CA pool (or times out)
deadline := time.Now().Add(2 * time.Second)
for {
newPool := watcher.GetCAPool()
if newPool != oldPool {
t.Log("CA pool successfully updated.")
return
}
if time.Now().After(deadline) {
t.Fatal("timed out waiting for CA pool to be updated")
}
time.Sleep(100 * time.Millisecond)
}
}
53 changes: 24 additions & 29 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package config
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
Expand All @@ -24,23 +23,22 @@ import (
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

tamanifest "github.com/aws/amazon-cloudwatch-agent-operator/internal/manifests/targetallocator"
)

const (
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultFilterStrategy = "relabel-config"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DefaultFilterStrategy removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it wasn't being used.

DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = ""
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultClientCertMountPath = tamanifest.ClientCertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = DefaultClientCertMountPath + "/tls-ca.crt"
)

type Config struct {
Expand Down Expand Up @@ -150,7 +148,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
}

func unmarshal(cfg *Config, configFile string) error {

yamlFile, err := os.ReadFile(configFile)
if err != nil {
return err
Expand Down Expand Up @@ -217,31 +214,29 @@ func ValidateConfig(config *Config) error {
}

func (c HTTPSServerConfig) NewTLSConfig(ctx context.Context) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
}

certWatcher, err := certwatcher.New(c.TLSCertFilePath, c.TLSKeyFilePath)
certWatcher, err := NewCertAndCAWatcher(c.TLSCertFilePath, c.TLSKeyFilePath, c.CAFilePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating certwatcher: %w", err)
}
tlsConfig.GetCertificate = certWatcher.GetCertificate

go func() {
_ = certWatcher.Start(ctx)
}()

if c.CAFilePath == "" {
return tlsConfig, nil
// Create the TLS config
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
GetCertificate: certWatcher.GetCertificate,
ClientCAs: certWatcher.GetCAPool(),
ClientAuth: tls.RequireAndVerifyClientCert,
}

caCert, err := os.ReadFile(c.CAFilePath)
caCertPool := x509.NewCertPool()
if err != nil {
return nil, err
// Dynamically update the CA pool if needed
tlsConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
newTLSConfig := tlsConfig.Clone()
newTLSConfig.ClientCAs = certWatcher.GetCAPool()
return newTLSConfig, nil
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

return tlsConfig, nil
}
Loading
Loading