Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement mTLS resources and configuration for Target Allocator server #284

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ FROM golang:1.22 as builder

# set goproxy=direct
ENV GOPROXY direct
ENV GOINSECURE go.opencensus.io
Copy link
Contributor Author

@musa-asad musa-asad Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be added temporarily because the TLS certificate for go.opencensus.io has expired, which is breaking our workflow.

Suggested change
ENV GOINSECURE go.opencensus.io


WORKDIR /workspace
# Copy the Go Modules manifests
Expand Down
113 changes: 113 additions & 0 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/certwatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package config

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
)

// CertAndCAWatcher pairs a controller-runtime certificate watcher for the
// serving certificate/key with an fsnotify watcher for a CA bundle file, and
// holds the most recently loaded CA pool.
type CertAndCAWatcher struct {
	// certWatcher backs GetCertificate for the certificate/key pair.
	certWatcher *certwatcher.CertWatcher

	// caWatcher delivers filesystem events for the file at caFilePath.
	caWatcher  *fsnotify.Watcher
	caFilePath string

	// mu guards caPool: watchCA replaces it under the write lock and
	// GetCAPool reads it under the read lock.
	mu     sync.RWMutex
	caPool *x509.CertPool
}

// NewCertAndCAWatcher builds a CertAndCAWatcher for the given certificate,
// key and CA bundle paths.
//
// It creates the certificate watcher, loads the initial CA pool from caPath
// and registers caPath with a new fsnotify watcher. Nothing is watched
// actively until Start is called. An error is returned if any of those steps
// fails; in that case the fsnotify watcher created here is closed again so
// its OS resources are not leaked.
func NewCertAndCAWatcher(certPath, keyPath, caPath string) (*CertAndCAWatcher, error) {
	certWatcher, err := certwatcher.New(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("error creating cert watcher: %w", err)
	}

	caPool, err := loadCAPool(caPath)
	if err != nil {
		return nil, fmt.Errorf("error loading CA pool: %w", err)
	}

	caWatcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, fmt.Errorf("error creating CA file watcher: %w", err)
	}
	if err := caWatcher.Add(caPath); err != nil {
		// Best-effort cleanup: the Add failure is the error worth reporting,
		// so a secondary Close error is intentionally dropped.
		_ = caWatcher.Close()
		return nil, fmt.Errorf("error adding CA file to watcher: %w", err)
	}

	return &CertAndCAWatcher{
		certWatcher: certWatcher,
		caFilePath:  caPath,
		caPool:      caPool,
		caWatcher:   caWatcher,
	}, nil
}

// loadCAPool reads the PEM file at caPath and returns a certificate pool
// containing the certificates found in it.
//
// It returns an error if the file cannot be read, or if the file does not
// contain at least one certificate that parses. Rejecting an empty or
// malformed bundle keeps callers from silently swapping in an empty pool
// (for example while the file is only partially written).
func loadCAPool(caPath string) (*x509.CertPool, error) {
	caCert, err := os.ReadFile(caPath)
	if err != nil {
		return nil, fmt.Errorf("error reading CA file: %w", err)
	}

	caCertPool := x509.NewCertPool()
	// AppendCertsFromPEM reports whether any certificate was parsed.
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("no valid PEM certificates found in CA file %q", caPath)
	}
	return caCertPool, nil
}

// Start runs the certificate watcher and the CA file watcher in background
// goroutines and blocks until ctx is cancelled.
//
// It returns nil once ctx is done. If the certificate watcher stops with an
// error while ctx is still live, that error is returned instead of being
// discarded. In every case the CA fsnotify watcher is closed before Start
// returns, which also ends the watchCA goroutine.
func (w *CertAndCAWatcher) Start(ctx context.Context) error {
	// Close the CA watcher on the way out so it does not outlive Start.
	// A Close error at shutdown is not actionable, so it is dropped.
	defer func() { _ = w.caWatcher.Close() }()

	// Buffered so the goroutine can always deliver its result and exit,
	// even if Start has already returned.
	certErr := make(chan error, 1)
	go func() {
		certErr <- w.certWatcher.Start(ctx)
	}()

	go w.watchCA(ctx)

	select {
	case err := <-certErr:
		// Only report failures that happened while we were supposed to be
		// running; errors surfaced during shutdown are ignored as before.
		if err != nil && ctx.Err() == nil {
			return fmt.Errorf("error running cert watcher: %w", err)
		}
		<-ctx.Done()
		return nil
	case <-ctx.Done():
		return nil
	}
}

// watchCA reloads the CA pool whenever the watched CA file is written,
// created or removed. It returns when ctx is cancelled or when the fsnotify
// watcher's channels are closed.
//
// A reload that fails (file missing, unreadable, ...) leaves the previously
// loaded pool in place.
func (w *CertAndCAWatcher) watchCA(ctx context.Context) {
	for {
		select {
		case event, ok := <-w.caWatcher.Events:
			if !ok {
				return
			}
			if !event.Op.Has(fsnotify.Write) && !event.Op.Has(fsnotify.Create) && !event.Op.Has(fsnotify.Remove) {
				continue
			}

			// Needed in case the file was removed: the watch has to be
			// registered again on the path. This is done before (and
			// independently of) the reload, so a failed reload can no longer
			// leave the file unwatched.
			if event.Op.Has(fsnotify.Remove) {
				// Give the replacement file a moment to appear, but do not
				// hold up shutdown while waiting.
				select {
				case <-time.After(100 * time.Millisecond):
				case <-ctx.Done():
					return
				}
				// Best effort: if the file is still absent the watch cannot
				// be re-established here, and there is no caller to report to.
				_ = w.caWatcher.Add(w.caFilePath)
			}

			newPool, err := loadCAPool(w.caFilePath)
			if err != nil {
				// Keep serving the last good pool.
				continue
			}
			w.mu.Lock()
			w.caPool = newPool
			w.mu.Unlock()
		case _, ok := <-w.caWatcher.Errors:
			// Drain watcher errors so an unread error can never stall the
			// delivery of further events.
			if !ok {
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

// GetCertificate hands the client hello to the wrapped certificate watcher
// and returns whatever certificate and error it produces.
func (w *CertAndCAWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cert, err := w.certWatcher.GetCertificate(clientHello)
	return cert, err
}

// GetCAPool returns the CA pool currently held by the watcher. The read is
// done under the read lock so it does not race with a reload in watchCA.
func (w *CertAndCAWatcher) GetCAPool() *x509.CertPool {
	w.mu.RLock()
	pool := w.caPool
	w.mu.RUnlock()
	return pool
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package config
60 changes: 24 additions & 36 deletions cmd/amazon-cloudwatch-agent-target-allocator/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package config
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
Expand All @@ -24,23 +23,22 @@ import (
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

tamanifest "github.com/aws/amazon-cloudwatch-agent-operator/internal/manifests/targetallocator"
)

const (
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultFilterStrategy = "relabel-config"
DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = ""
DefaultResyncTime = 5 * time.Minute
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
DefaultAllocationStrategy = "consistent-hashing"
DefaultListenAddr = ":8443"
DefaultCertMountPath = tamanifest.TACertMountPath
DefaultClientCertMountPath = tamanifest.ClientCertMountPath
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
DefaultCABundlePath = DefaultClientCertMountPath + "/tls-ca.crt"
)

type Config struct {
Expand All @@ -66,7 +64,6 @@ type PrometheusCRConfig struct {
}

type HTTPSServerConfig struct {
Enabled bool `yaml:"enabled,omitempty"`
ListenAddr string `yaml:"listen_addr,omitempty"`
CAFilePath string `yaml:"ca_file_path,omitempty"`
TLSCertFilePath string `yaml:"tls_cert_file_path,omitempty"`
Expand Down Expand Up @@ -121,11 +118,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
return err
}

target.HTTPS.Enabled, err = getHttpsEnabled(flagSet)
if err != nil {
return err
}

target.HTTPS.ListenAddr, err = getHttpsListenAddr(flagSet)
if err != nil {
return err
Expand All @@ -150,7 +142,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
}

func unmarshal(cfg *Config, configFile string) error {

yamlFile, err := os.ReadFile(configFile)
if err != nil {
return err
Expand All @@ -169,7 +160,6 @@ func CreateDefaultConfig() Config {
},
AllocationStrategy: &allocation_strategy,
HTTPS: HTTPSServerConfig{
Enabled: true,
ListenAddr: DefaultListenAddr,
CAFilePath: DefaultCABundlePath,
TLSCertFilePath: DefaultTLSCertPath,
Expand Down Expand Up @@ -217,31 +207,29 @@ func ValidateConfig(config *Config) error {
}

func (c HTTPSServerConfig) NewTLSConfig(ctx context.Context) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
}

certWatcher, err := certwatcher.New(c.TLSCertFilePath, c.TLSKeyFilePath)
certWatcher, err := NewCertAndCAWatcher(c.TLSCertFilePath, c.TLSKeyFilePath, c.CAFilePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating certwatcher: %w", err)
}
tlsConfig.GetCertificate = certWatcher.GetCertificate

go func() {
_ = certWatcher.Start(ctx)
}()

if c.CAFilePath == "" {
return tlsConfig, nil
// Create the TLS config
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
GetCertificate: certWatcher.GetCertificate,
ClientCAs: certWatcher.GetCAPool(),
ClientAuth: tls.RequireAndVerifyClientCert,
}

caCert, err := os.ReadFile(c.CAFilePath)
caCertPool := x509.NewCertPool()
if err != nil {
return nil, err
// Dynamically update the CA pool if needed
tlsConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
newTLSConfig := tlsConfig.Clone()
newTLSConfig.ClientCAs = certWatcher.GetCAPool()
return newTLSConfig, nil
}
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.ClientCAs = caCertPool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

return tlsConfig, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ func TestLoad(t *testing.T) {
ScrapeInterval: model.Duration(time.Second * 60),
},
HTTPS: HTTPSServerConfig{
Enabled: true,
ListenAddr: DefaultListenAddr,
CAFilePath: "/path/to/ca.pem",
TLSCertFilePath: "/path/to/cert.pem",
Expand Down Expand Up @@ -114,7 +113,6 @@ func TestLoad(t *testing.T) {
ScrapeInterval: DefaultCRScrapeInterval,
},
HTTPS: HTTPSServerConfig{
Enabled: true,
ListenAddr: DefaultListenAddr,
CAFilePath: DefaultCABundlePath,
TLSCertFilePath: DefaultTLSCertPath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ label_selector:
prometheus_cr:
scrape_interval: 60s
https:
enabled: true
ca_file_path: /path/to/ca.pem
tls_cert_file_path: /path/to/cert.pem
tls_key_file_path: /path/to/key.pem
Expand Down
Loading
Loading