Skip to content

Commit

Permalink
chore(appx): detailed jwt error log
Browse files Browse the repository at this point in the history
  • Loading branch information
rot1024 committed Aug 25, 2023
1 parent 20e5d57 commit 543ac15
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 134 deletions.
31 changes: 3 additions & 28 deletions appx/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/golang-jwt/jwt"
"github.com/reearth/reearthx/log"
"github.com/reearth/reearthx/util"
"github.com/samber/lo"
)

Expand All @@ -36,7 +34,7 @@ type JWTProvider struct {
TTL *int
}

func (p JWTProvider) validator() (*validator.Validator, error) {
func (p JWTProvider) validator() (JWTValidator, error) {
issuerURL, err := url.Parse(p.ISS)
if err != nil {
return nil, fmt.Errorf("failed to parse the issuer url: %w", err)
Expand Down Expand Up @@ -64,7 +62,7 @@ func (p JWTProvider) validator() (*validator.Validator, error) {
aud = []string{}
}

return validator.New(
return NewJWTValidatorWithError(
provider.KeyFunc,
algorithm,
issuerURL.String(),
Expand All @@ -86,29 +84,6 @@ func (c *customClaims) Validate(_ context.Context) error {
return nil
}

type MultiValidator []*validator.Validator

func NewMultiValidator(providers []JWTProvider) (MultiValidator, error) {
return util.TryMap(providers, func(p JWTProvider) (*validator.Validator, error) {
return p.validator()
})
}

// ValidateToken Trys to validate the token with each validator
// NOTE: the last validation error only is returned
func (mv MultiValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) {
for _, v := range mv {
res, err = v.ValidateToken(ctx, tokenString)
if err == nil {
return
}
}

log.Errorfc(ctx, "auth: invalid JWT: %v", err)
log.Debugfc(ctx, "auth: JWT: %s", tokenString)
return
}

// AuthInfoMiddleware loads claim from context and attach the user info.
func AuthInfoMiddleware(key any) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
Expand Down Expand Up @@ -139,7 +114,7 @@ func AuthInfoMiddleware(key any) func(http.Handler) http.Handler {
}

func AuthMiddleware(providers []JWTProvider, key any, optional bool) (func(http.Handler) http.Handler, error) {
v, err := NewMultiValidator(providers)
v, err := NewJWTMultipleValidator(providers)
if err != nil {
return nil, err
}
Expand Down
105 changes: 0 additions & 105 deletions appx/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,113 +20,8 @@ import (
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
"gopkg.in/square/go-jose.v2"
jwt2 "gopkg.in/square/go-jose.v2/jwt"
)

func TestMultiValidator(t *testing.T) {
key := lo.Must(rsa.GenerateKey(rand.Reader, 2048))

httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder(
http.MethodGet,
"https://example.com/.well-known/openid-configuration",
util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})),
)
httpmock.RegisterResponder(
http.MethodGet,
"https://example2.com/.well-known/openid-configuration",
util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})),
)
httpmock.RegisterResponder(
http.MethodGet,
"https://example.com/jwks",
httpmock.NewBytesResponder(http.StatusOK, lo.Must(json.Marshal(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{KeyID: "0", Key: &key.PublicKey, Algorithm: jwt.SigningMethodRS256.Name},
},
}))),
)

v, err := NewMultiValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

expiry := time.Now().Add(time.Hour * 24).Unix()
claims := jwt.MapClaims{
"exp": expiry,
"iss": "https://example.com/",
"sub": "subsub",
"aud": []string{"a", "b"},
"name": "aaa",
"nickname": "bbb",
"email": "ccc",
"email_verified": true,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "0"
tokenString := lo.Must(token.SignedString(key))

claims2 := jwt.MapClaims{
"exp": expiry,
"iss": "https://example2.com/",
"sub": "subsub2",
"aud": "c",
"name": "aaa",
"nickname": "bbb",
}
token2 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims2)
token2.Header["kid"] = "0"
tokenString2 := lo.Must(token2.SignedString(key))

claims3 := jwt.MapClaims{
"exp": expiry,
"iss": "https://example3.com/",
"aud": "c",
}
token3 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims3)
token3.Header["kid"] = "0"
tokenString3 := lo.Must(token3.SignedString(key))

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.Equal(t, &validator.ValidatedClaims{
CustomClaims: &customClaims{
Name: "aaa",
Nickname: "bbb",
Email: "ccc",
EmailVerified: lo.ToPtr(true),
},
RegisteredClaims: validator.RegisteredClaims{
Issuer: "https://example.com/",
Subject: "subsub",
Audience: []string{"a", "b"},
Expiry: expiry,
},
}, res)

res2, err := v.ValidateToken(context.Background(), tokenString2)
assert.NoError(t, err)
assert.Equal(t, &validator.ValidatedClaims{
CustomClaims: &customClaims{
Name: "aaa",
Nickname: "bbb",
},
RegisteredClaims: validator.RegisteredClaims{
Issuer: "https://example2.com/",
Subject: "subsub2",
Audience: []string{"c"},
Expiry: expiry,
},
}, res2)

res3, err := v.ValidateToken(context.Background(), tokenString3)
assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer)
assert.Nil(t, res3)
}

func TestAuthInfoMiddleware(t *testing.T) {
key := struct{}{}
m := AuthInfoMiddleware(key)
Expand Down
81 changes: 81 additions & 0 deletions appx/jwt_validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package appx

import (
"context"
"errors"
"fmt"

"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/reearth/reearthx/log"
"github.com/reearth/reearthx/util"
"golang.org/x/exp/slices"
)

type JWTValidator interface {
ValidateToken(ctx context.Context, tokenString string) (interface{}, error)
}

// JWTValidatorWithError wraps "validator.Validator and attach iss and aud to the error message to make it easy to track errors.
type JWTValidatorWithError struct {
validator *validator.Validator
iss string
aud []string
}

func NewJWTValidatorWithError(
keyFunc func(context.Context) (interface{}, error),
signatureAlgorithm validator.SignatureAlgorithm,
issuerURL string,
audience []string,
opts ...validator.Option,
) (*JWTValidatorWithError, error) {
validator, err := validator.New(
keyFunc,
signatureAlgorithm,
issuerURL,
audience,
opts...,
)
if err != nil {
return nil, err
}
return &JWTValidatorWithError{
validator: validator,
iss: issuerURL,
aud: slices.Clone(audience),
}, nil
}

func (v *JWTValidatorWithError) ValidateToken(ctx context.Context, token string) (interface{}, error) {
res, err := v.validator.ValidateToken(ctx, token)
if err != nil {
err = fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err)
}
return res, err
}

type JWTMultipleValidator []JWTValidator

func NewJWTMultipleValidator(providers []JWTProvider) (JWTMultipleValidator, error) {
return util.TryMap(providers, func(p JWTProvider) (JWTValidator, error) {
return p.validator()
})
}

// ValidateToken Trys to validate the token with each validator
// NOTE: the last validation error only is returned
func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) {
for _, v := range mv {
var err2 error
res, err2 = v.ValidateToken(ctx, tokenString)
if err2 == nil {
err = nil
return
}
err = errors.Join(err, err2)
}

log.Debugfc(ctx, "auth: invalid JWT token: %s", tokenString)
log.Errorfc(ctx, "auth: invalid JWT token: %v", err)
return
}
124 changes: 124 additions & 0 deletions appx/jwt_validator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package appx

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"testing"
"time"

"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/golang-jwt/jwt"
"github.com/jarcoal/httpmock"
"github.com/reearth/reearthx/util"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
jwt2 "gopkg.in/square/go-jose.v2/jwt"
)

func TestMultiValidator(t *testing.T) {
key := lo.Must(rsa.GenerateKey(rand.Reader, 2048))

httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder(
http.MethodGet,
"https://example.com/.well-known/openid-configuration",
util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})),
)
httpmock.RegisterResponder(
http.MethodGet,
"https://example2.com/.well-known/openid-configuration",
util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})),
)
httpmock.RegisterResponder(
http.MethodGet,
"https://example.com/jwks",
httpmock.NewBytesResponder(http.StatusOK, lo.Must(json.Marshal(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{KeyID: "0", Key: &key.PublicKey, Algorithm: jwt.SigningMethodRS256.Name},
},
}))),
)

v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

expiry := time.Now().Add(time.Hour * 24).Unix()
claims := jwt.MapClaims{
"exp": expiry,
"iss": "https://example.com/",
"sub": "subsub",
"aud": []string{"a", "b"},
"name": "aaa",
"nickname": "bbb",
"email": "ccc",
"email_verified": true,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "0"
tokenString := lo.Must(token.SignedString(key))

claims2 := jwt.MapClaims{
"exp": expiry,
"iss": "https://example2.com/",
"sub": "subsub2",
"aud": "c",
"name": "aaa",
"nickname": "bbb",
}
token2 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims2)
token2.Header["kid"] = "0"
tokenString2 := lo.Must(token2.SignedString(key))

claims3 := jwt.MapClaims{
"exp": expiry,
"iss": "https://example3.com/",
"aud": "c",
}
token3 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims3)
token3.Header["kid"] = "0"
tokenString3 := lo.Must(token3.SignedString(key))

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.Equal(t, &validator.ValidatedClaims{
CustomClaims: &customClaims{
Name: "aaa",
Nickname: "bbb",
Email: "ccc",
EmailVerified: lo.ToPtr(true),
},
RegisteredClaims: validator.RegisteredClaims{
Issuer: "https://example.com/",
Subject: "subsub",
Audience: []string{"a", "b"},
Expiry: expiry,
},
}, res)

res2, err := v.ValidateToken(context.Background(), tokenString2)
assert.NoError(t, err)
assert.Equal(t, &validator.ValidatedClaims{
CustomClaims: &customClaims{
Name: "aaa",
Nickname: "bbb",
},
RegisteredClaims: validator.RegisteredClaims{
Issuer: "https://example2.com/",
Subject: "subsub2",
Audience: []string{"c"},
Expiry: expiry,
},
}, res2)

res3, err := v.ValidateToken(context.Background(), tokenString3)
assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer)
assert.Nil(t, res3)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/reearth/reearthx

go 1.19
go 1.20

require (
github.com/99designs/gqlgen v0.17.12
Expand Down

0 comments on commit 543ac15

Please sign in to comment.