Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple Issuers #12

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
env:
LOG_PRETTY: True
LOG_LEVEL: Trace
ISSUER: "https://token.actions.githubusercontent.com"
ISSUERS: "https://token.actions.githubusercontent.com"
AUDIENCE: "https://github.com/equinor"
SUBJECTS: repo:equinor/radix-oauth-guard:pull_request,testmultiplesubjects
GH_TOKEN: ${{ steps.get-id-token.outputs.result }}
Expand Down
15 changes: 15 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
run:
timeout: 30m

linters:
enable:
- errcheck
- gosimple
- govet
- ineffassign
- staticcheck
- unused
- zerologlint

issues:
max-same-issues: 0
141 changes: 99 additions & 42 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,125 @@ package main
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/rs/zerolog"
"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/rs/zerolog/log"
"gopkg.in/go-jose/go-jose.v2/jwt"
)

var (
errInvalidAuthorizationHeader = errors.New("invalid Authorization header")
)

type Verifier interface {
Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error)
type KeyFunc func(ctx context.Context) (interface{}, error)
type controller struct {
providers map[string]KeyFunc
audience string
subjects []string
}

// AuthHandler returns a Handler to authenticate requests
func AuthHandler(subjects []string, verifier Verifier) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Trace().Func(func(e *zerolog.Event) {
headers := r.Header.Clone()
headers.Del("Authorization")
if r.Header.Get("Authorization") != "" {
headers.Set("Authorization", "!REMOVED!")
}
e.Interface("headers", headers)
}).Msg("Request details")
t := time.Now()

auth := r.Header.Get("Authorization")
jwt, err := parseAuthHeader(auth)
// NewAuthHandler returns a Handler to authenticate requests
func NewAuthHandler(audience string, subjects, issuers []string) (RouteMapper, error) {
providers := make(map[string]KeyFunc, len(issuers))
for _, issuer := range issuers {
issuerUrl, err := url.Parse(issuer)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusUnauthorized).Msg("Unauthorized")
return
return nil, err
}

token, err := verifier.Verify(r.Context(), jwt)
provider := jwks.NewCachingProvider(issuerUrl, 5*time.Hour)
providers[issuer] = provider.KeyFunc
}

if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusUnauthorized).Msg("Unauthorized")
return
}
c := &controller{
providers: providers,
audience: audience,
subjects: subjects,
}
return func(mux *http.ServeMux) {
mux.Handle("/auth", c)
}, nil
}

subject := token.Subject
found := slices.Contains(subjects, subject)
if !found {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusForbidden).Str("sub", subject).Msg("Forbidden")
return
}
func (c *controller) ServeHTTP(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
authHeader, err := parseAuthHeader(auth)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Unauthorized"))
log.Info().Err(err).Msg("Unauthorized: Invalid auth header")
return
}

claims, err := c.getClaims(r.Context(), authHeader)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Unauthorized"))
log.Warn().Err(err).Msg("Forbidden: Invalid token")
return
}

subject := claims.Subject

found := slices.Contains(c.subjects, subject)
if !found {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
log.Warn().Str("sub", subject).Msg("Forbidden")
return
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
log.Info().Str("sub", subject).Msg("Authorized")
}

func (c *controller) getClaims(ctx context.Context, authHeader string) (*jwt.Claims, error) {
var unsafeClaims jwt.Claims
token, err := jwt.ParseSigned(authHeader)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
}
err = token.UnsafeClaimsWithoutVerification(&unsafeClaims)
if err != nil {
return nil, fmt.Errorf("failed to extract JWT unsafeClaims: %w", err)
}
var keyId string
if len(token.Headers) == 1 {
keyId = token.Headers[0].KeyID
}
if keyId == "" {
return nil, fmt.Errorf("failed to find keyId in headers")
}

issuer := unsafeClaims.Issuer
keyFunc, ok := c.providers[issuer]
if !ok {
return nil, fmt.Errorf("unknown issuer: %s", issuer)
}
key, err := keyFunc(ctx)
if err != nil {
return nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

var verifiedClaims jwt.Claims
err = token.Claims(key, &verifiedClaims)
if err != nil {
return nil, fmt.Errorf("failed to verify token unsafeClaims: %w", err)
}

expected := jwt.Expected{Audience: []string{c.audience}}
if err = verifiedClaims.Validate(expected); err != nil {
return nil, fmt.Errorf("failed to verify token unsafeClaims: %w", err)
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
log.Info().Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusOK).Str("sub", subject).Msg("Authorized")
})
return &verifiedClaims, nil
}

func parseAuthHeader(authorization string) (string, error) {
Expand Down
Loading
Loading