From 543ac159f33b0386b273877bf879811e5fd6986f Mon Sep 17 00:00:00 2001 From: rot1024 Date: Fri, 25 Aug 2023 16:29:07 +0900 Subject: [PATCH] chore(appx): detailed jwt error log --- appx/jwt.go | 31 +--------- appx/jwt_test.go | 105 ------------------------------- appx/jwt_validator.go | 81 ++++++++++++++++++++++++ appx/jwt_validator_test.go | 124 +++++++++++++++++++++++++++++++++++++ go.mod | 2 +- 5 files changed, 209 insertions(+), 134 deletions(-) create mode 100644 appx/jwt_validator.go create mode 100644 appx/jwt_validator_test.go diff --git a/appx/jwt.go b/appx/jwt.go index f4e2dffa..090dfac3 100644 --- a/appx/jwt.go +++ b/appx/jwt.go @@ -12,8 +12,6 @@ import ( "github.com/auth0/go-jwt-middleware/v2/jwks" "github.com/auth0/go-jwt-middleware/v2/validator" "github.com/golang-jwt/jwt" - "github.com/reearth/reearthx/log" - "github.com/reearth/reearthx/util" "github.com/samber/lo" ) @@ -36,7 +34,7 @@ type JWTProvider struct { TTL *int } -func (p JWTProvider) validator() (*validator.Validator, error) { +func (p JWTProvider) validator() (JWTValidator, error) { issuerURL, err := url.Parse(p.ISS) if err != nil { return nil, fmt.Errorf("failed to parse the issuer url: %w", err) @@ -64,7 +62,7 @@ func (p JWTProvider) validator() (*validator.Validator, error) { aud = []string{} } - return validator.New( + return NewJWTValidatorWithError( provider.KeyFunc, algorithm, issuerURL.String(), @@ -86,29 +84,6 @@ func (c *customClaims) Validate(_ context.Context) error { return nil } -type MultiValidator []*validator.Validator - -func NewMultiValidator(providers []JWTProvider) (MultiValidator, error) { - return util.TryMap(providers, func(p JWTProvider) (*validator.Validator, error) { - return p.validator() - }) -} - -// ValidateToken Trys to validate the token with each validator -// NOTE: the last validation error only is returned -func (mv MultiValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) { - for _, v := range mv { - res, err = v.ValidateToken(ctx, tokenString) - if err == nil { - return - } - } - - log.Errorfc(ctx, "auth: invalid JWT: %v", err) - log.Debugfc(ctx, "auth: JWT: %s", tokenString) - return -} - // AuthInfoMiddleware loads claim from context and attach the user info. func AuthInfoMiddleware(key any) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -139,7 +114,7 @@ func AuthInfoMiddleware(key any) func(http.Handler) http.Handler { } func AuthMiddleware(providers []JWTProvider, key any, optional bool) (func(http.Handler) http.Handler, error) { - v, err := NewMultiValidator(providers) + v, err := NewJWTMultipleValidator(providers) if err != nil { return nil, err } diff --git a/appx/jwt_test.go b/appx/jwt_test.go index a87cadb5..40d4835c 100644 --- a/appx/jwt_test.go +++ b/appx/jwt_test.go @@ -20,113 +20,8 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" "gopkg.in/square/go-jose.v2" - jwt2 "gopkg.in/square/go-jose.v2/jwt" ) -func TestMultiValidator(t *testing.T) { - key := lo.Must(rsa.GenerateKey(rand.Reader, 2048)) - - httpmock.Activate() - defer httpmock.DeactivateAndReset() - httpmock.RegisterResponder( - http.MethodGet, - "https://example.com/.well-known/openid-configuration", - util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})), - ) - httpmock.RegisterResponder( - http.MethodGet, - "https://example2.com/.well-known/openid-configuration", - util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})), - ) - httpmock.RegisterResponder( - http.MethodGet, - "https://example.com/jwks", - httpmock.NewBytesResponder(http.StatusOK, lo.Must(json.Marshal(jose.JSONWebKeySet{ - Keys: []jose.JSONWebKey{ - {KeyID: "0", Key: &key.PublicKey, Algorithm: jwt.SigningMethodRS256.Name}, - }, - }))), - ) - - v, err := NewMultiValidator([]JWTProvider{ - {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, - {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, - }) - assert.NoError(t, err) - - expiry := time.Now().Add(time.Hour * 24).Unix() - claims := jwt.MapClaims{ - "exp": expiry, - "iss": "https://example.com/", - "sub": "subsub", - "aud": []string{"a", "b"}, - "name": "aaa", - "nickname": "bbb", - "email": "ccc", - "email_verified": true, - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = "0" - tokenString := lo.Must(token.SignedString(key)) - - claims2 := jwt.MapClaims{ - "exp": expiry, - "iss": "https://example2.com/", - "sub": "subsub2", - "aud": "c", - "name": "aaa", - "nickname": "bbb", - } - token2 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims2) - token2.Header["kid"] = "0" - tokenString2 := lo.Must(token2.SignedString(key)) - - claims3 := jwt.MapClaims{ - "exp": expiry, - "iss": "https://example3.com/", - "aud": "c", - } - token3 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims3) - token3.Header["kid"] = "0" - tokenString3 := lo.Must(token3.SignedString(key)) - - res, err := v.ValidateToken(context.Background(), tokenString) - assert.NoError(t, err) - assert.Equal(t, &validator.ValidatedClaims{ - CustomClaims: &customClaims{ - Name: "aaa", - Nickname: "bbb", - Email: "ccc", - EmailVerified: lo.ToPtr(true), - }, - RegisteredClaims: validator.RegisteredClaims{ - Issuer: "https://example.com/", - Subject: "subsub", - Audience: []string{"a", "b"}, - Expiry: expiry, - }, - }, res) - - res2, err := v.ValidateToken(context.Background(), tokenString2) - assert.NoError(t, err) - assert.Equal(t, &validator.ValidatedClaims{ - CustomClaims: &customClaims{ - Name: "aaa", - Nickname: "bbb", - }, - RegisteredClaims: validator.RegisteredClaims{ - Issuer: "https://example2.com/", - Subject: "subsub2", - Audience: []string{"c"}, - Expiry: expiry, - }, - }, res2) - - res3, err := v.ValidateToken(context.Background(), tokenString3) - assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer) - assert.Nil(t, res3) -} - func TestAuthInfoMiddleware(t *testing.T) { key := struct{}{} m := AuthInfoMiddleware(key) diff --git a/appx/jwt_validator.go b/appx/jwt_validator.go new file mode 100644 index 00000000..03feecdd --- /dev/null +++ b/appx/jwt_validator.go @@ -0,0 +1,81 @@ +package appx + +import ( + "context" + "errors" + "fmt" + + "github.com/auth0/go-jwt-middleware/v2/validator" + "github.com/reearth/reearthx/log" + "github.com/reearth/reearthx/util" + "golang.org/x/exp/slices" +) + +type JWTValidator interface { + ValidateToken(ctx context.Context, tokenString string) (interface{}, error) +} + +// JWTValidatorWithError wraps "validator.Validator and attach iss and aud to the error message to make it easy to track errors. +type JWTValidatorWithError struct { + validator *validator.Validator + iss string + aud []string +} + +func NewJWTValidatorWithError( + keyFunc func(context.Context) (interface{}, error), + signatureAlgorithm validator.SignatureAlgorithm, + issuerURL string, + audience []string, + opts ...validator.Option, +) (*JWTValidatorWithError, error) { + validator, err := validator.New( + keyFunc, + signatureAlgorithm, + issuerURL, + audience, + opts..., + ) + if err != nil { + return nil, err + } + return &JWTValidatorWithError{ + validator: validator, + iss: issuerURL, + aud: slices.Clone(audience), + }, nil +} + +func (v *JWTValidatorWithError) ValidateToken(ctx context.Context, token string) (interface{}, error) { + res, err := v.validator.ValidateToken(ctx, token) + if err != nil { + err = fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err) + } + return res, err +} + +type JWTMultipleValidator []JWTValidator + +func NewJWTMultipleValidator(providers []JWTProvider) (JWTMultipleValidator, error) { + return util.TryMap(providers, func(p JWTProvider) (JWTValidator, error) { + return p.validator() + }) +} + +// ValidateToken Trys to validate the token with each validator +// NOTE: the last validation error only is returned +func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) { + for _, v := range mv { + var err2 error + res, err2 = v.ValidateToken(ctx, tokenString) + if err2 == nil { + err = nil + return + } + err = errors.Join(err, err2) + } + + log.Debugfc(ctx, "auth: invalid JWT token: %s", tokenString) + log.Errorfc(ctx, "auth: invalid JWT token: %v", err) + return +} diff --git a/appx/jwt_validator_test.go b/appx/jwt_validator_test.go new file mode 100644 index 00000000..189ba351 --- /dev/null +++ b/appx/jwt_validator_test.go @@ -0,0 +1,124 @@ +package appx + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/auth0/go-jwt-middleware/v2/validator" + "github.com/golang-jwt/jwt" + "github.com/jarcoal/httpmock" + "github.com/reearth/reearthx/util" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + jwt2 "gopkg.in/square/go-jose.v2/jwt" +) + +func TestMultiValidator(t *testing.T) { + key := lo.Must(rsa.GenerateKey(rand.Reader, 2048)) + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder( + http.MethodGet, + "https://example.com/.well-known/openid-configuration", + util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})), + ) + httpmock.RegisterResponder( + http.MethodGet, + "https://example2.com/.well-known/openid-configuration", + util.DR(httpmock.NewJsonResponder(http.StatusOK, map[string]string{"jwks_uri": "https://example.com/jwks"})), + ) + httpmock.RegisterResponder( + http.MethodGet, + "https://example.com/jwks", + httpmock.NewBytesResponder(http.StatusOK, lo.Must(json.Marshal(jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + {KeyID: "0", Key: &key.PublicKey, Algorithm: jwt.SigningMethodRS256.Name}, + }, + }))), + ) + + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + expiry := time.Now().Add(time.Hour * 24).Unix() + claims := jwt.MapClaims{ + "exp": expiry, + "iss": "https://example.com/", + "sub": "subsub", + "aud": []string{"a", "b"}, + "name": "aaa", + "nickname": "bbb", + "email": "ccc", + "email_verified": true, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "0" + tokenString := lo.Must(token.SignedString(key)) + + claims2 := jwt.MapClaims{ + "exp": expiry, + "iss": "https://example2.com/", + "sub": "subsub2", + "aud": "c", + "name": "aaa", + "nickname": "bbb", + } + token2 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims2) + token2.Header["kid"] = "0" + tokenString2 := lo.Must(token2.SignedString(key)) + + claims3 := jwt.MapClaims{ + "exp": expiry, + "iss": "https://example3.com/", + "aud": "c", + } + token3 := jwt.NewWithClaims(jwt.SigningMethodRS256, claims3) + token3.Header["kid"] = "0" + tokenString3 := lo.Must(token3.SignedString(key)) + + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.Equal(t, &validator.ValidatedClaims{ + CustomClaims: &customClaims{ + Name: "aaa", + Nickname: "bbb", + Email: "ccc", + EmailVerified: lo.ToPtr(true), + }, + RegisteredClaims: validator.RegisteredClaims{ + Issuer: "https://example.com/", + Subject: "subsub", + Audience: []string{"a", "b"}, + Expiry: expiry, + }, + }, res) + + res2, err := v.ValidateToken(context.Background(), tokenString2) + assert.NoError(t, err) + assert.Equal(t, &validator.ValidatedClaims{ + CustomClaims: &customClaims{ + Name: "aaa", + Nickname: "bbb", + }, + RegisteredClaims: validator.RegisteredClaims{ + Issuer: "https://example2.com/", + Subject: "subsub2", + Audience: []string{"c"}, + Expiry: expiry, + }, + }, res2) + + res3, err := v.ValidateToken(context.Background(), tokenString3) + assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer) + assert.Nil(t, res3) +} diff --git a/go.mod b/go.mod index 9ef317cf..3e822461 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/reearth/reearthx -go 1.19 +go 1.20 require ( github.com/99designs/gqlgen v0.17.12