-
Notifications
You must be signed in to change notification settings - Fork 28
Implement mTLS resources and configuration for Target Allocator server #284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
273c558
5877fef
6808981
1f0f0d9
dfd7ce6
0b8d26c
0cb2e1d
91f3008
397b8dc
00911e4
3487d4a
62f794d
a713e78
dd199c2
08af589
7771eac
f8440f2
223ef59
d08dfb6
d804cd0
d6c5c6d
be9dcde
a48985b
d73025a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package config | ||
|
||
import ( | ||
"context" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"fmt" | ||
"os" | ||
"sync" | ||
"time" | ||
|
||
"github.com/fsnotify/fsnotify" | ||
"sigs.k8s.io/controller-runtime/pkg/certwatcher" | ||
) | ||
|
||
type CertAndCAWatcher struct { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this same as the fluent-bit one? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cert watcher exists on the server-side, so this one is not the same as the fluent-bit one. The fluent-bit one is in the CloudWatch Agent: https://github.com/aws/amazon-cloudwatch-agent/blob/799f0c3bdcfc18fcadf3b5e0c1c9ebaa543e18a8/internal/tls/certWatcher.go#L22. |
||
certWatcher *certwatcher.CertWatcher | ||
|
||
caFilePath string | ||
caPool *x509.CertPool | ||
caWatcher *fsnotify.Watcher | ||
|
||
mu sync.RWMutex | ||
} | ||
|
||
func NewCertAndCAWatcher(certPath, keyPath, caPath string) (*CertAndCAWatcher, error) { | ||
certWatcher, err := certwatcher.New(certPath, keyPath) | ||
if err != nil { | ||
return nil, fmt.Errorf("error creating cert watcher: %w", err) | ||
} | ||
|
||
caPool, err := loadCAPool(caPath) | ||
if err != nil { | ||
return nil, fmt.Errorf("error loading CA pool: %w", err) | ||
} | ||
|
||
caWatcher, err := fsnotify.NewWatcher() | ||
if err != nil { | ||
return nil, fmt.Errorf("error creating CA file watcher: %w", err) | ||
} | ||
if err := caWatcher.Add(caPath); err != nil { | ||
return nil, fmt.Errorf("error adding CA file to watcher: %w", err) | ||
} | ||
|
||
return &CertAndCAWatcher{ | ||
certWatcher: certWatcher, | ||
caFilePath: caPath, | ||
caPool: caPool, | ||
caWatcher: caWatcher, | ||
}, nil | ||
} | ||
|
||
func loadCAPool(caPath string) (*x509.CertPool, error) { | ||
caCert, err := os.ReadFile(caPath) | ||
caCertPool := x509.NewCertPool() | ||
if err != nil { | ||
return nil, fmt.Errorf("error reading CA file: %w", err) | ||
} | ||
caCertPool.AppendCertsFromPEM(caCert) | ||
return caCertPool, nil | ||
} | ||
|
||
func (w *CertAndCAWatcher) Start(ctx context.Context) error { | ||
go func() { | ||
_ = w.certWatcher.Start(ctx) | ||
}() | ||
|
||
go w.watchCA(ctx) | ||
|
||
<-ctx.Done() | ||
return nil | ||
} | ||
|
||
func (w *CertAndCAWatcher) watchCA(ctx context.Context) { | ||
for { | ||
select { | ||
case event, ok := <-w.caWatcher.Events: | ||
if !ok { | ||
return | ||
} | ||
if event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Remove) { | ||
newPool, err := loadCAPool(w.caFilePath) | ||
if err != nil { | ||
continue | ||
} | ||
w.mu.Lock() | ||
w.caPool = newPool | ||
w.mu.Unlock() | ||
|
||
// needed incase file removed | ||
if event.Op.Has(fsnotify.Remove) { | ||
time.Sleep(100 * time.Millisecond) | ||
_ = w.caWatcher.Add(w.caFilePath) | ||
} | ||
} | ||
case <-ctx.Done(): | ||
return | ||
} | ||
} | ||
} | ||
|
||
func (w *CertAndCAWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||
return w.certWatcher.GetCertificate(clientHello) | ||
} | ||
|
||
func (w *CertAndCAWatcher) GetCAPool() *x509.CertPool { | ||
w.mu.RLock() | ||
defer w.mu.RUnlock() | ||
return w.caPool | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package config | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"math/big" | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func generateSelfSignedCertAndKey(commonName string) (certPEM, keyPEM []byte, err error) { | ||
// Generate RSA key | ||
priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
// Create a minimal self-signed certificate template | ||
serial, err := rand.Int(rand.Reader, big.NewInt(1<<63-1)) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
template := &x509.Certificate{ | ||
SerialNumber: serial, | ||
Subject: pkix.Name{ | ||
CommonName: commonName, | ||
}, | ||
NotBefore: time.Now().Add(-time.Hour), | ||
NotAfter: time.Now().Add(time.Hour), | ||
|
||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign, | ||
IsCA: true, | ||
BasicConstraintsValid: true, | ||
} | ||
|
||
// Self-sign the certificate | ||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
// Encode cert + key to PEM | ||
var certBuf, keyBuf bytes.Buffer | ||
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der}) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
return certBuf.Bytes(), keyBuf.Bytes(), nil | ||
} | ||
|
||
func TestCertAndCAWatcher_UpdatesCA(t *testing.T) { | ||
t.Parallel() | ||
|
||
// Generate a server cert/key for certwatcher | ||
certPEM, keyPEM, err := generateSelfSignedCertAndKey("test-server") | ||
if err != nil { | ||
t.Fatalf("failed to generate server cert/key: %v", err) | ||
} | ||
|
||
// Generate two distinct self-signed certs to represent old CA vs new CA | ||
oldCAPEM, _, err := generateSelfSignedCertAndKey("old-ca") | ||
if err != nil { | ||
t.Fatalf("failed to generate old CA: %v", err) | ||
} | ||
newCAPEM, _, err := generateSelfSignedCertAndKey("new-ca") | ||
if err != nil { | ||
t.Fatalf("failed to generate new CA: %v", err) | ||
} | ||
|
||
// Write all these PEM files into a temp dir | ||
tmpDir := t.TempDir() | ||
|
||
certPath := filepath.Join(tmpDir, "tls.crt") | ||
keyPath := filepath.Join(tmpDir, "tls.key") | ||
caPath := filepath.Join(tmpDir, "ca.crt") | ||
|
||
if err := os.WriteFile(certPath, certPEM, 0600); err != nil { | ||
t.Fatalf("failed to write cert file: %v", err) | ||
} | ||
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { | ||
t.Fatalf("failed to write key file: %v", err) | ||
} | ||
if err := os.WriteFile(caPath, oldCAPEM, 0600); err != nil { | ||
t.Fatalf("failed to write initial CA file: %v", err) | ||
} | ||
|
||
// Create the CertAndCAWatcher using our files | ||
watcher, err := NewCertAndCAWatcher(certPath, keyPath, caPath) | ||
if err != nil { | ||
t.Fatalf("failed to create CertAndCAWatcher: %v", err) | ||
} | ||
|
||
// Start the watcher in the background | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
go func() { | ||
_ = watcher.Start(ctx) | ||
}() | ||
|
||
// Record the initial CA pool pointer | ||
oldPool := watcher.GetCAPool() | ||
if oldPool == nil { | ||
t.Fatal("expected non-nil initial CA pool") | ||
} | ||
|
||
// Overwrite the CA file with newCAPEM, triggering a reload | ||
if err := os.WriteFile(caPath, newCAPEM, 0600); err != nil { | ||
t.Fatalf("failed to write new CA file: %v", err) | ||
} | ||
|
||
// Loop until the watcher updates the CA pool (or times out) | ||
deadline := time.Now().Add(2 * time.Second) | ||
for { | ||
newPool := watcher.GetCAPool() | ||
if newPool != oldPool { | ||
t.Log("CA pool successfully updated.") | ||
return | ||
} | ||
if time.Now().After(deadline) { | ||
t.Fatal("timed out waiting for CA pool to be updated") | ||
} | ||
time.Sleep(100 * time.Millisecond) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ package config | |
import ( | ||
"context" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"errors" | ||
"fmt" | ||
"io/fs" | ||
|
@@ -24,23 +23,22 @@ import ( | |
"k8s.io/client-go/tools/clientcmd" | ||
"k8s.io/klog/v2" | ||
ctrl "sigs.k8s.io/controller-runtime" | ||
"sigs.k8s.io/controller-runtime/pkg/certwatcher" | ||
"sigs.k8s.io/controller-runtime/pkg/log/zap" | ||
|
||
tamanifest "github.com/aws/amazon-cloudwatch-agent-operator/internal/manifests/targetallocator" | ||
) | ||
|
||
const ( | ||
DefaultResyncTime = 5 * time.Minute | ||
DefaultConfigFilePath string = "/conf/targetallocator.yaml" | ||
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30) | ||
DefaultAllocationStrategy = "consistent-hashing" | ||
DefaultFilterStrategy = "relabel-config" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DefaultFilterStrategy removed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it wasn't being used. |
||
DefaultListenAddr = ":8443" | ||
DefaultCertMountPath = tamanifest.TACertMountPath | ||
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key" | ||
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt" | ||
DefaultCABundlePath = "" | ||
DefaultResyncTime = 5 * time.Minute | ||
DefaultConfigFilePath string = "/conf/targetallocator.yaml" | ||
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30) | ||
DefaultAllocationStrategy = "consistent-hashing" | ||
DefaultListenAddr = ":8443" | ||
DefaultCertMountPath = tamanifest.TACertMountPath | ||
DefaultClientCertMountPath = tamanifest.ClientCertMountPath | ||
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key" | ||
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt" | ||
DefaultCABundlePath = DefaultClientCertMountPath + "/tls-ca.crt" | ||
) | ||
|
||
type Config struct { | ||
|
@@ -150,7 +148,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error { | |
} | ||
|
||
func unmarshal(cfg *Config, configFile string) error { | ||
|
||
yamlFile, err := os.ReadFile(configFile) | ||
if err != nil { | ||
return err | ||
|
@@ -217,31 +214,29 @@ func ValidateConfig(config *Config) error { | |
} | ||
|
||
func (c HTTPSServerConfig) NewTLSConfig(ctx context.Context) (*tls.Config, error) { | ||
tlsConfig := &tls.Config{ | ||
MinVersion: tls.VersionTLS13, | ||
} | ||
|
||
certWatcher, err := certwatcher.New(c.TLSCertFilePath, c.TLSKeyFilePath) | ||
certWatcher, err := NewCertAndCAWatcher(c.TLSCertFilePath, c.TLSKeyFilePath, c.CAFilePath) | ||
if err != nil { | ||
return nil, err | ||
return nil, fmt.Errorf("error creating certwatcher: %w", err) | ||
} | ||
tlsConfig.GetCertificate = certWatcher.GetCertificate | ||
|
||
go func() { | ||
_ = certWatcher.Start(ctx) | ||
}() | ||
|
||
if c.CAFilePath == "" { | ||
return tlsConfig, nil | ||
// Create the TLS config | ||
tlsConfig := &tls.Config{ | ||
MinVersion: tls.VersionTLS13, | ||
GetCertificate: certWatcher.GetCertificate, | ||
ClientCAs: certWatcher.GetCAPool(), | ||
ClientAuth: tls.RequireAndVerifyClientCert, | ||
} | ||
|
||
caCert, err := os.ReadFile(c.CAFilePath) | ||
caCertPool := x509.NewCertPool() | ||
if err != nil { | ||
return nil, err | ||
// Dynamically update the CA pool if needed | ||
tlsConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { | ||
newTLSConfig := tlsConfig.Clone() | ||
newTLSConfig.ClientCAs = certWatcher.GetCAPool() | ||
return newTLSConfig, nil | ||
} | ||
caCertPool.AppendCertsFromPEM(caCert) | ||
tlsConfig.ClientCAs = caCertPool | ||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert | ||
|
||
return tlsConfig, nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is a generic cert watcher..we can probably move it outside of target allocator so other parts of the code can use it. Maybe /internal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I can move it out.