Skip to content

Commit

Permalink
tls config
Browse files Browse the repository at this point in the history
  • Loading branch information
adityathebe committed Jun 13, 2024
1 parent fa1b95f commit a725866
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 5 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/onsi/gomega v1.27.6
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/samber/lo v1.39.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
Expand Down Expand Up @@ -76,7 +77,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robertkrimen/otto v0.2.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/samber/lo v1.39.0 // indirect
github.com/stoewer/go-strcase v1.3.0 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/vadimi/go-ntlm v1.2.1 // indirect
Expand Down
52 changes: 48 additions & 4 deletions http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -175,19 +176,57 @@ func (c *Client) DisableKeepAlive(val bool) *Client {
return c
}

// InsecureSkipVerify controls whether a client verifies the server's
// certificate chain and host name
func (c *Client) InsecureSkipVerify(val bool) *Client {
func (c *Client) initTLSConfig() {
if c.httpClient.Transport == nil {
c.httpClient.Transport = http.DefaultTransport
}

customTransport := c.httpClient.Transport.(*http.Transport).Clone()

if customTransport.TLSClientConfig == nil {
customTransport.TLSClientConfig = &tls.Config{}
}
}

func (c *Client) TLSConfig(conf TLSConfig) (*Client, error) {
c.initTLSConfig()

customTransport := c.httpClient.Transport.(*http.Transport).Clone()

if conf.CA != "" {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

if !certPool.AppendCertsFromPEM([]byte(conf.CA)) {
return nil, fmt.Errorf("failed to append ca certificate")
}
customTransport.TLSClientConfig.RootCAs = certPool
}

if conf.Cert != "" {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

if !certPool.AppendCertsFromPEM([]byte(conf.Cert)) {
return nil, fmt.Errorf("failed to append client Certificate certificate")
}
customTransport.TLSClientConfig.RootCAs = certPool
}

c.tlsConfig = customTransport.TLSClientConfig
c.httpClient.Transport = customTransport
return c, nil
}

// InsecureSkipVerify controls whether a client verifies the server's
// certificate chain and host name
func (c *Client) InsecureSkipVerify(val bool) *Client {
c.initTLSConfig()

customTransport := c.httpClient.Transport.(*http.Transport).Clone()
customTransport.TLSClientConfig.InsecureSkipVerify = val
c.tlsConfig = customTransport.TLSClientConfig
c.httpClient.Transport = customTransport
Expand Down Expand Up @@ -220,6 +259,11 @@ func (c *Client) Auth(username, password string) *Client {
return c
}

type TLSConfig struct {
CA string
Cert string
}

func (c *Client) OAuth(config middlewares.OauthConfig) *Client {
c.Use(middlewares.NewOauthTransport(config).RoundTripper)
return c
Expand Down
124 changes: 124 additions & 0 deletions http/http_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package http_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math"
"math/big"
"net/http"
"sync"
"testing"
"time"

chttp "github.com/flanksource/commons/http"
"github.com/flanksource/commons/logger"
)

func TestTLSConfig(t *testing.T) {
// Generate a self-signed certificate
certPemData, cert, err := generateSelfSignedCert()
if err != nil {
t.Fatal(err)
}

port := "18080"
server, err := tlsServer(cert, port)
if err != nil {
t.Fatal(err)
}
logger.Infof("Listening on port %s", port)

var wg sync.WaitGroup
wg.Add(1)
go func() {
err := server.ListenAndServeTLS("", "")
logger.Infof("server error: %v", err)
wg.Done()
}()

go func() {
time.Sleep(time.Second)
server.Shutdown(context.Background())

Check failure on line 47 in http/http_tls_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `server.Shutdown` is not checked (errcheck)
}()

client, err := chttp.NewClient().TLSConfig(chttp.TLSConfig{Cert: string(certPemData)})
if err != nil {
t.Fatal(err)
}

response, err := client.BaseURL(fmt.Sprintf("https://localhost:%s", port)).R(context.Background()).Get("/")
if err != nil {
t.Fatal(err)
}

r, err := response.AsString()
if err != nil {
t.Fatal(err)
}
if r != "Hello, World!" {
t.Fatal(r)
}

wg.Wait()
}

func tlsServer(cert tls.Certificate, port string) (*http.Server, error) {
server := &http.Server{
Addr: fmt.Sprintf(":%s", port),
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, World!"))

Check failure on line 75 in http/http_tls_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `w.Write` is not checked (errcheck)
}),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
}
return server, nil
}

func generateSelfSignedCert() ([]byte, tls.Certificate, error) {
subject := pkix.Name{
Organization: []string{"Example Company"},
}

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, tls.Certificate{}, err
}

serialNumber, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
return nil, tls.Certificate{}, err
}

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: subject,
NotBefore: time.Now(),
DNSNames: []string{"localhost"},
NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
IsCA: true,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, tls.Certificate{}, err
}

certPEMData := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPEMData := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})

cert, err := tls.X509KeyPair(certPEMData, keyPEMData)
if err != nil {
return nil, tls.Certificate{}, err
}

return certPEMData, cert, nil
}

0 comments on commit a725866

Please sign in to comment.