diff --git a/src/autoscaler/helpers/auth/auth_suite_test.go b/src/autoscaler/helpers/auth/auth_suite_test.go new file mode 100644 index 0000000000..cc266fc4ed --- /dev/null +++ b/src/autoscaler/helpers/auth/auth_suite_test.go @@ -0,0 +1,13 @@ +package auth_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAuth(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Auth Suite") +} diff --git a/src/autoscaler/helpers/auth/xfcc_auth.go b/src/autoscaler/helpers/auth/xfcc_auth.go new file mode 100644 index 0000000000..2728cb7b5d --- /dev/null +++ b/src/autoscaler/helpers/auth/xfcc_auth.go @@ -0,0 +1,118 @@ +package auth + +import ( + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + + "code.cloudfoundry.org/lager/v3" +) + +var ErrorWrongSpace = errors.New("space guid is wrong") +var ErrorWrongOrg = errors.New("org guid is wrong") +var ErrXFCCHeaderNotFound = errors.New("xfcc header not found") + +type XFCCAuthMiddleware struct { + logger lager.Logger + spaceGuid string + orgGuid string +} + +func (m *XFCCAuthMiddleware) checkAuth(r *http.Request) error { + xfccHeader := r.Header.Get("X-Forwarded-Client-Cert") + if xfccHeader == "" { + return ErrXFCCHeaderNotFound + } + + data, err := base64.StdEncoding.DecodeString(removeQuotes(xfccHeader)) + if err != nil { + return fmt.Errorf("base64 parsing failed: %w", err) + } + + cert, err := x509.ParseCertificate(data) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + + if getSpaceGuid(cert) != m.spaceGuid { + return ErrorWrongSpace + } + + if getOrgGuid(cert) != m.orgGuid { + return ErrorWrongOrg + } + + return nil + +} + +func (m *XFCCAuthMiddleware) XFCCAuthenticationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := m.checkAuth(r) + + if err != nil { + m.logger.Error("xfcc-auth-error", err) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +func NewXfccAuthMiddleware(logger lager.Logger, orgGuid, spaceGuid string) *XFCCAuthMiddleware { + return &XFCCAuthMiddleware{ + logger: logger, + orgGuid: orgGuid, + spaceGuid: spaceGuid, + } +} + +func getSpaceGuid(cert *x509.Certificate) string { + var certSpaceGuid string + for _, ou := range cert.Subject.OrganizationalUnit { + + if strings.Contains(ou, "space:") { + kv := mapFrom(ou) + certSpaceGuid = kv["space"] + break + } + } + return certSpaceGuid +} + +func mapFrom(input string) map[string]string { + result := make(map[string]string) + + r := regexp.MustCompile(`(\w+):(\w+-\w+)`) + matches := r.FindAllStringSubmatch(input, -1) + + for _, match := range matches { + result[match[1]] = match[2] + } + return result +} + +func getOrgGuid(cert *x509.Certificate) string { + var certOrgGuid string + for _, ou := range cert.Subject.OrganizationalUnit { + // capture from string k:v with regex + if strings.Contains(ou, "org:") { + kv := mapFrom(ou) + certOrgGuid = kv["org"] + break + } + } + return certOrgGuid +} + +func removeQuotes(xfccHeader string) string { + if xfccHeader[0] == '"' { + xfccHeader = xfccHeader[1 : len(xfccHeader)-1] + } + return xfccHeader +} diff --git a/src/autoscaler/helpers/auth/xfcc_auth_test.go b/src/autoscaler/helpers/auth/xfcc_auth_test.go new file mode 100644 index 0000000000..6a5a54ac87 --- /dev/null +++ b/src/autoscaler/helpers/auth/xfcc_auth_test.go @@ -0,0 +1,153 @@ +package auth_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + + "code.cloudfoundry.org/app-autoscaler/src/autoscaler/helpers/auth" + + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +var _ = Describe("XfccAuthMiddleware", func() { + var ( + server *httptest.Server + resp *http.Response + + buffer *gbytes.Buffer + + err error + xfccClientCert []byte + + orgGuid string + spaceGuid string + ) + + AfterEach(func() { + server.Close() + }) + + JustBeforeEach(func() { + logger := lagertest.NewTestLogger("xfcc-auth-test") + buffer = logger.Buffer() + xm := auth.NewXfccAuthMiddleware(logger, orgGuid, spaceGuid) + + server = httptest.NewServer(xm.XFCCAuthenticationMiddleware(handler)) + + req, err := http.NewRequest("GET", server.URL+"/some-protected-endpoint", nil) + + if len(xfccClientCert) > 0 { + block, _ := pem.Decode(xfccClientCert) + Expect(err).NotTo(HaveOccurred()) + Expect(block).ShouldNot(BeNil()) + + req.Header.Add("X-Forwarded-Client-Cert", base64.StdEncoding.EncodeToString(block.Bytes)) + } + Expect(err).NotTo(HaveOccurred()) + + resp, err = http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) + }) + + BeforeEach(func() { + orgGuid = "org-guid" + spaceGuid = "space-guid" + }) + + When("xfcc header is not set", func() { + BeforeEach(func() { + xfccClientCert = []byte{} + }) + + It("should return 401", func() { + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + Eventually(buffer).Should(gbytes.Say(auth.ErrXFCCHeaderNotFound.Error())) + }) + }) + + When("xfcc cert matches org and space guids", func() { + BeforeEach(func() { + xfccClientCert, err = generateClientCert(orgGuid, spaceGuid) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should return 200", func() { + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + }) + + When("xfcc cert does not match org guid", func() { + BeforeEach(func() { + xfccClientCert, err = generateClientCert("wrong-org-guid", spaceGuid) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should return 401", func() { + Eventually(buffer).Should(gbytes.Say(auth.ErrorWrongOrg.Error())) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + + }) + + When("xfcc cert does not match space guid", func() { + BeforeEach(func() { + xfccClientCert, err = generateClientCert(orgGuid, "wrong-space-guid") + Expect(err).NotTo(HaveOccurred()) + }) + + It("should return 401", func() { + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + Eventually(buffer).Should(gbytes.Say(auth.ErrorWrongSpace.Error())) + }) + }) +}) + +// generateClientCert generates a client certificate with the specified spaceGUID and orgGUID +// included in the organizational unit string. +func generateClientCert(orgGUID, spaceGUID string) ([]byte, error) { + // Generate a random serial number for the certificate + // + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, err + } + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + // Create a new X.509 certificate template + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + OrganizationalUnit: []string{fmt.Sprintf("space:%s org:%s", spaceGUID, orgGUID)}, + }, + } + // Generate the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + // Encode the certificate to PEM format + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + return certPEM, nil +}