diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index 463842ae1..e656afbc5 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -9,6 +9,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/token/jwt" + "github.com/ory/x/errorsx" ) type StatelessJWTValidator struct { @@ -74,7 +75,9 @@ func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token strin return "", err } - // TODO: From here we assume it is an access token, but how do we know it is really and that is not an ID token? + if !IsJWTProfileAccessToken(t) { + return "", errorsx.WithStack(fosite.ErrRequestUnauthorized.WithDebug("The provided token is not a valid RFC9068 JWT Profile Access Token as it is missing the header 'typ' value of 'at+jwt' ")) + } requester := AccessTokenJWTToRequest(t) @@ -86,3 +89,26 @@ func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token strin return fosite.AccessToken, nil } + +// IsJWTProfileAccessToken validates a *jwt.Token is actually a RFC9068 JWT Profile Access Token by checking the +// relevant header as per https://datatracker.ietf.org/doc/html/rfc9068#section-2.1 which explicitly states that +// the header MUST include a typ of 'at+jwt' or 'application/at+jwt' with a preference of 'at+jwt'. +func IsJWTProfileAccessToken(token *jwt.Token) bool { + var ( + raw any + typ string + ok bool + ) + + if token == nil { + return false + } + + if raw, ok = token.Header[string(jwt.JWTHeaderType)]; !ok { + return false + } + + typ, ok = raw.(string) + + return ok && (typ == "at+jwt" || typ == "application/at+jwt") +} diff --git a/handler/oauth2/introspector_jwt_test.go b/handler/oauth2/introspector_jwt_test.go index 77482dc8e..06eb7c0c5 100644 --- a/handler/oauth2/introspector_jwt_test.go +++ b/handler/oauth2/introspector_jwt_test.go @@ -99,6 +99,23 @@ func TestIntrospectJWT(t *testing.T) { return token }, }, + { + description: "should fail bad typ", + token: func() string { + jwt := jwtValidCase(fosite.AccessToken) + + s := jwt.Session.(*JWTSession) + + s.JWTHeader.Extra["typ"] = "JWT" + + jwt.Session = s + + token, _, err := strat.GenerateAccessToken(context.Background(), jwt) + assert.NoError(t, err) + return token + }, + expectErr: fosite.ErrRequestUnauthorized, + }, } { t.Run(fmt.Sprintf("case=%d:%v", k, c.description), func(t *testing.T) { if c.scopes == nil { @@ -142,3 +159,64 @@ func BenchmarkIntrospectJWT(b *testing.B) { assert.NoError(b, err) } + +func TestIsJWTProfileAccessToken(t *testing.T) { + testCases := []struct { + name string + have *jwt.Token + expected bool + }{ + { + "ShouldPassTypATJWT", + &jwt.Token{ + Header: map[string]interface{}{ + "typ": "at+jwt", + }, + }, + true, + }, + { + "ShouldPassTypApplicationATJWT", + &jwt.Token{ + Header: map[string]interface{}{ + "typ": "application/at+jwt", + }, + }, + true, + }, + { + "ShouldFailJWT", + &jwt.Token{ + Header: map[string]interface{}{ + "typ": "JWT", + }, + }, + false, + }, + { + "ShouldFailNoValue", + &jwt.Token{ + Header: map[string]interface{}{}, + }, + false, + }, + { + "ShouldFailNilValue", + &jwt.Token{ + Header: nil, + }, + false, + }, + { + "ShouldFailNilInput", + nil, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsJWTProfileAccessToken(tc.have)) + }) + } +} diff --git a/handler/oauth2/strategy_jwt_session.go b/handler/oauth2/strategy_jwt_session.go index e6c94b30d..b845c7508 100644 --- a/handler/oauth2/strategy_jwt_session.go +++ b/handler/oauth2/strategy_jwt_session.go @@ -40,8 +40,15 @@ func (j *JWTSession) GetJWTClaims() jwt.JWTClaimsContainer { func (j *JWTSession) GetJWTHeader() *jwt.Headers { if j.JWTHeader == nil { - j.JWTHeader = &jwt.Headers{} + j.JWTHeader = &jwt.Headers{ + Extra: map[string]interface{}{ + "typ": "at+jwt", + }, + } + } else if j.JWTHeader.Extra["typ"] == nil { + j.JWTHeader.Extra["typ"] = "at+jwt" } + return j.JWTHeader } diff --git a/handler/oauth2/strategy_jwt_session_test.go b/handler/oauth2/strategy_jwt_session_test.go new file mode 100644 index 000000000..a53f50f39 --- /dev/null +++ b/handler/oauth2/strategy_jwt_session_test.go @@ -0,0 +1,46 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "github.com/ory/fosite/token/jwt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestJWTSession_GetJWTHeader(t *testing.T) { + testCases := []struct { + name string + have *JWTSession + expected string + }{ + { + "ShouldReturnDefaultTyp", + &JWTSession{}, + "at+jwt", + }, + { + "ShouldReturnConfiguredATJWTTyp", + &JWTSession{JWTHeader: &jwt.Headers{Extra: map[string]interface{}{ + "typ": "at+jwt", + }}}, + "at+jwt", + }, + { + "ShouldReturnConfiguredJWTTyp", + &JWTSession{JWTHeader: &jwt.Headers{Extra: map[string]interface{}{ + "typ": "JWT", + }}}, + "JWT", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + header := tc.have.GetJWTHeader() + + assert.Equal(t, tc.expected, header.Get("typ")) + }) + } +} diff --git a/handler/oauth2/strategy_jwt_test.go b/handler/oauth2/strategy_jwt_test.go index cdfb9ce12..4cc554aa5 100644 --- a/handler/oauth2/strategy_jwt_test.go +++ b/handler/oauth2/strategy_jwt_test.go @@ -194,6 +194,16 @@ func TestAccessToken(t *testing.T) { require.Len(t, parts, 3, "%s - %v", token, parts) assert.Equal(t, parts[2], signature) + rawHeader, err := base64.RawURLEncoding.DecodeString(parts[0]) + require.NoError(t, err) + + var header map[string]interface{} + require.NoError(t, json.Unmarshal(rawHeader, &header)) + + typ, ok := header["typ"] + assert.True(t, ok) + assert.Equal(t, "at+jwt", typ) + rawPayload, err := base64.RawURLEncoding.DecodeString(parts[1]) require.NoError(t, err) var payload map[string]interface{}