Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve mTLS test #131

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 96 additions & 30 deletions http/http_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,119 @@ func TestTLSConfig(t *testing.T) {
t.Fatal(err)
}

port := "18080"
server, err := tlsServer(*serverCrt, port)
_, _, badClientPEM, badClientKeyPem, err := createCert(nil, nil, "bad-client")
if err != nil {
t.Fatal(err)
}
logger.Infof("Listening on port %s", port)

serverReady := make(chan struct{})
go func() {
close(serverReady)
err := server.ListenAndServeTLS("", "")
logger.Infof("server error: %v", err)
}()

serverTerminate := make(chan struct{})
go func() {
<-serverTerminate
_ = server.Shutdown(context.Background())
}()
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(caPEM) {
t.Fatal(err)
}

<-serverReady
testData := []struct {
name string
tlsConfig chttp.TLSConfig
expectErr bool
serverTLS *tls.Config
clientTLS chttp.TLSConfig
}{
{"withca", chttp.TLSConfig{CA: string(caPEM)}},
{"with client certs and CA", chttp.TLSConfig{Cert: string(clientPEM), Key: string(clientKeyPem), CA: string(caPEM)}}, // FIXME: Setup an HTTPs server that requires client auth
{
name: "Client provides CA",
clientTLS: chttp.TLSConfig{CA: string(caPEM)},
serverTLS: &tls.Config{
Certificates: []tls.Certificate{*serverCrt},
},
},
{
name: "Client doesn't provide CA",
clientTLS: chttp.TLSConfig{},
expectErr: true,
serverTLS: &tls.Config{
Certificates: []tls.Certificate{*serverCrt},
},
},
{
name: "mTLS | client provides client certs",
clientTLS: chttp.TLSConfig{
Cert: string(clientPEM),
Key: string(clientKeyPem),
CA: string(caPEM),
},
serverTLS: &tls.Config{
Certificates: []tls.Certificate{*serverCrt},
ClientCAs: certPool,
ClientAuth: tls.RequireAndVerifyClientCert,
},
},
{
name: "mTLS | client doesn't provides client certs",
clientTLS: chttp.TLSConfig{
CA: string(caPEM),
},
expectErr: true,
serverTLS: &tls.Config{
Certificates: []tls.Certificate{*serverCrt},
ClientCAs: certPool,
ClientAuth: tls.RequireAndVerifyClientCert,
},
},
{
name: "mTLS | client provides bad certs",
expectErr: true,
clientTLS: chttp.TLSConfig{
CA: string(caPEM),
Cert: string(badClientPEM),
Key: string(badClientKeyPem),
},
serverTLS: &tls.Config{
Certificates: []tls.Certificate{*serverCrt},
ClientCAs: certPool,
ClientAuth: tls.RequireAndVerifyClientCert,
},
},
}

for _, td := range testData {
t.Run(td.name, func(t *testing.T) {
client, err := chttp.NewClient().TLSConfig(td.tlsConfig)
port := "18080"
server, err := tlsServer(port, td.serverTLS)
if err != nil {
t.Fatal(err)
}

response, err := client.BaseURL(fmt.Sprintf("https://localhost:%s", port)).R(context.Background()).Get("/")
serverReady := make(chan struct{})
go func() {
close(serverReady)
err := server.ListenAndServeTLS("", "")
logger.Infof("server error: %v", err)
}()

serverTerminate := make(chan struct{})
go func() {
<-serverTerminate
_ = server.Shutdown(context.Background())
}()

<-serverReady
defer func() { serverTerminate <- struct{}{} }()

client, err := chttp.NewClient().TLSConfig(td.clientTLS)
if err != nil {
t.Fatal(err)
}

response, err := client.BaseURL(fmt.Sprintf("https://localhost:%s", port)).R(context.Background()).Get("/")
if err != nil {
if !td.expectErr {
t.Fatal(err)
}
return
} else {
if td.expectErr {
t.Fatal("expected error")
}
}

r, err := response.AsString()
if err != nil {
t.Fatal(err)
Expand All @@ -87,19 +159,15 @@ func TestTLSConfig(t *testing.T) {
}
})
}

serverTerminate <- struct{}{}
}

func tlsServer(cert tls.Certificate, port string) (*http.Server, error) {
func tlsServer(port string, tlsConfig *tls.Config) (*http.Server, error) {
server := &http.Server{
Addr: fmt.Sprintf(":%s", port),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("Hello, World!"))
}),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
TLSConfig: tlsConfig,
}
return server, nil
}
Expand All @@ -118,7 +186,7 @@ func createCert(parent *x509.Certificate, signerKey any, cn string) (*x509.Certi
},
DNSNames: []string{cn},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: isCa,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
Expand Down Expand Up @@ -153,13 +221,11 @@ func createCert(parent *x509.Certificate, signerKey any, cn string) (*x509.Certi
}
pemBytes := pem.EncodeToMemory(pemBlock)

// Create tls.Certificate
keyBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return nil, nil, nil, nil, err
}

// Create a new private key PEM
pemBlock = &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
Expand Down
Loading