diff --git a/access_request.go b/access_request.go index 2ef86c3f1..bd823c871 100644 --- a/access_request.go +++ b/access_request.go @@ -7,6 +7,8 @@ type AccessRequest struct { GrantTypes Arguments `json:"grantTypes" gorethink:"grantTypes"` HandledGrantType Arguments `json:"handledGrantType" gorethink:"handledGrantType"` + JWTClaims map[string]interface{} `json:"jwt_claims" gorethink:"jwtClaims"` + Request } @@ -23,3 +25,11 @@ func NewAccessRequest(session Session) *AccessRequest { func (a *AccessRequest) GetGrantTypes() Arguments { return a.GrantTypes } + +func (a *AccessRequest) GetJWTClaims() map[string]interface{} { + return a.JWTClaims +} + +func (a *AccessRequest) SetJWTClaims(claims map[string]interface{}) { + a.JWTClaims = claims +} diff --git a/client_authentication.go b/client_authentication.go index 82f3c3872..d0e6d4cdf 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -187,6 +187,10 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt strings.Join(f.Config.GetTokenURLs(ctx), "' or '"))) } + if accessRequest, ok := ctx.Value(AccessRequestContextKey).(*AccessRequest); ok { + accessRequest.SetJWTClaims(claims) + } + return client, nil } else if len(assertionType) > 0 { return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unknown client_assertion_type '%s'.", assertionType)) diff --git a/handler/rfc7523/handler.go b/handler/rfc7523/handler.go index 4c7767e8a..db768e983 100644 --- a/handler/rfc7523/handler.go +++ b/handler/rfc7523/handler.go @@ -70,7 +70,8 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite } claims := jwt.Claims{} - if err := token.Claims(key, &claims); err != nil { + rawClaims := make(map[string]interface{}) + if err := token.Claims(key, &claims, &rawClaims); err != nil { return errorsx.WithStack(fosite.ErrInvalidGrant. WithHint("Unable to verify the integrity of the 'assertion' value."). WithWrap(err).WithDebug(err.Error()), @@ -115,6 +116,8 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite session.SetExpiresAt(fosite.AccessToken, time.Now().UTC().Add(atLifespan).Round(time.Second)) session.SetSubject(claims.Subject) + request.SetJWTClaims(rawClaims) + return nil } diff --git a/internal/access_request.go b/internal/access_request.go index d2f42e326..43b43c5d1 100644 --- a/internal/access_request.go +++ b/internal/access_request.go @@ -121,6 +121,20 @@ func (mr *MockAccessRequesterMockRecorder) GetID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockAccessRequester)(nil).GetID)) } +// GetJWTClaims mocks base method. +func (m *MockAccessRequester) GetJWTClaims() map[string]interface{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetJWTClaims") + ret0, _ := ret[0].(map[string]interface{}) + return ret0 +} + +// GetJWTClaims indicates an expected call of GetJWTClaims. +func (mr *MockAccessRequesterMockRecorder) GetJWTClaims() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetJWTClaims", reflect.TypeOf((*MockAccessRequester)(nil).GetJWTClaims)) +} + // GetRequestForm mocks base method. func (m *MockAccessRequester) GetRequestForm() url.Values { m.ctrl.T.Helper() @@ -253,6 +267,18 @@ func (mr *MockAccessRequesterMockRecorder) SetID(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetID", reflect.TypeOf((*MockAccessRequester)(nil).SetID), arg0) } +// SetJWTClaims mocks base method. +func (m *MockAccessRequester) SetJWTClaims(arg0 map[string]interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetJWTClaims", arg0) +} + +// SetJWTClaims indicates an expected call of SetJWTClaims. +func (mr *MockAccessRequesterMockRecorder) SetJWTClaims(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetJWTClaims", reflect.TypeOf((*MockAccessRequester)(nil).SetJWTClaims), arg0) +} + // SetRequestedAudience mocks base method. func (m *MockAccessRequester) SetRequestedAudience(arg0 fosite.Arguments) { m.ctrl.T.Helper() diff --git a/oauth2.go b/oauth2.go index 0827b8ed6..5283d897f 100644 --- a/oauth2.go +++ b/oauth2.go @@ -250,6 +250,12 @@ type AccessRequester interface { // GetGrantType returns the requests grant type. GetGrantTypes() (grantTypes Arguments) + // GetJWTClaims returns the request's decoded JWT claims (RFC 7523). + GetJWTClaims() (claims map[string]interface{}) + + // SetJWTClaims stores decoded JWT claims for the request (RFC 7523). + SetJWTClaims(claims map[string]interface{}) + Requester }