diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 083e67ac5da..2585a3a159d 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -118,8 +118,8 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, RequestedAt: r.GetRequestedAt(), InternalExpiresAt: sqlxx.NullTime(expiresAt), Client: r.GetClient().GetID(), - Scopes: strings.Join(r.GetRequestedScopes(), "|"), - GrantedScope: strings.Join(r.GetGrantedScopes(), "|"), + Scopes: strings.Join(escapeDelimiter(r.GetRequestedScopes()), "|"), + GrantedScope: strings.Join(escapeDelimiter(r.GetGrantedScopes()), "|"), GrantedAudience: strings.Join(r.GetGrantedAudience(), "|"), RequestedAudience: strings.Join(r.GetRequestedAudience(), "|"), Form: r.GetRequestForm().Encode(), @@ -179,13 +179,23 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session return nil, errorsx.WithStack(err) } + scopes, err := unescapeDelimiter(r.Scopes) + if err != nil { + return nil, errorsx.WithStack(err) + } + + grantedScopes, err := unescapeDelimiter(r.GrantedScope) + if err != nil { + return nil, errorsx.WithStack(err) + } + return &fosite.Request{ ID: r.Request, RequestedAt: r.RequestedAt, // ExpiresAt does not need to be populated as we get the expiry time from the session. Client: c, - RequestedScope: stringsx.Splitx(r.Scopes, "|"), - GrantedScope: stringsx.Splitx(r.GrantedScope, "|"), + RequestedScope: scopes, + GrantedScope: grantedScopes, RequestedAudience: stringsx.Splitx(r.RequestedAudience, "|"), GrantedAudience: stringsx.Splitx(r.GrantedAudience, "|"), Form: val, @@ -612,3 +622,29 @@ func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) (er p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), ) } + +func escapeDelimiter(scopes []string) []string { + escapedScopes := make([]string, len(scopes)) + for i, scope := range scopes { + if strings.Contains(scope, "|") { + escapedScopes[i] = url.QueryEscape(scope) + } else { + escapedScopes[i] = scope + } + } + return escapedScopes +} + +func unescapeDelimiter(scopes string) ([]string, error) { + updatedScopes := stringsx.Splitx(scopes, "|") + if strings.Contains(scopes, "%26") { + for i, scope := range updatedScopes { + unescapedScope, err := url.QueryUnescape(scope) + if err != nil { + return nil, errors.Errorf("Error while url unescaping scope: %s", scope) + } + updatedScopes[i] = unescapedScope + } + } + return updatedScopes, nil +}