Skip to content

Commit

Permalink
feat: provide login hints when registration fails due to duplicate cr…
Browse files Browse the repository at this point in the history
…edentials/addresses (#3430)

* feat: provide login hints when registration fails due to duplicate credentials or identifiers

* feat: identify edge cases and write tests

* chore: synchronize workspaces

* feat: make login hints configurable

* chore: synchronize workspaces

* chore: synchronize workspaces

* chore: synchronize workspaces

* chore: synchronize workspaces

---------

Co-authored-by: aeneasr <[email protected]>
  • Loading branch information
alnr and aeneasr authored Aug 15, 2023
1 parent 9f5982b commit 8b28469
Show file tree
Hide file tree
Showing 19 changed files with 449 additions and 29 deletions.
1 change: 1 addition & 0 deletions cmd/clidoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func init() {
"NewErrorValidationPasswordPolicyViolation": text.NewErrorValidationPasswordPolicyViolation("{reason}"),
"NewErrorValidationInvalidCredentials": text.NewErrorValidationInvalidCredentials(),
"NewErrorValidationDuplicateCredentials": text.NewErrorValidationDuplicateCredentials(),
"NewErrorValidationDuplicateCredentialsWithHints": text.NewErrorValidationDuplicateCredentialsWithHints("{reason}", nil, nil, ""),
"NewErrorValidationDuplicateCredentialsOnOIDCLink": text.NewErrorValidationDuplicateCredentialsOnOIDCLink(),
"NewErrorValidationTOTPVerifierWrong": text.NewErrorValidationTOTPVerifierWrong(),
"NewErrorValidationLookupAlreadyUsed": text.NewErrorValidationLookupAlreadyUsed(),
Expand Down
5 changes: 5 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ const (
ViperKeySelfServiceBrowserDefaultReturnTo = "selfservice." + DefaultBrowserReturnURL
ViperKeyURLsAllowedReturnToDomains = "selfservice.allowed_return_urls"
ViperKeySelfServiceRegistrationEnabled = "selfservice.flows.registration.enabled"
ViperKeySelfServiceRegistrationLoginHints = "selfservice.flows.registration.login_hints"
ViperKeySelfServiceRegistrationUI = "selfservice.flows.registration.ui_url"
ViperKeySelfServiceRegistrationRequestLifespan = "selfservice.flows.registration.lifespan"
ViperKeySelfServiceRegistrationAfter = "selfservice.flows.registration.after"
Expand Down Expand Up @@ -628,6 +629,10 @@ func (p *Config) SelfServiceFlowRegistrationEnabled(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeySelfServiceRegistrationEnabled)
}

func (p *Config) SelfServiceFlowRegistrationLoginHints(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeySelfServiceRegistrationLoginHints)
}

func (p *Config) SelfServiceFlowVerificationEnabled(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeySelfServiceVerificationEnabled)
}
Expand Down
6 changes: 6 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,12 @@
"description": "If set to true will enable [User Registration](https://www.ory.sh/kratos/docs/self-service/flows/user-registration/).",
"default": true
},
"login_hints": {
"type": "boolean",
"title": "Provide Login Hints on Failed Registration",
"description": "When registration fails because an account with the given credentials or addresses previously signed up, provide login hints about available methods to sign in to the user.",
"default": false
},
"ui_url": {
"title": "Registration UI URL",
"description": "URL where the Registration UI is hosted. Check the [reference implementation](https://github.com/ory/kratos-selfservice-ui-node).",
Expand Down
3 changes: 2 additions & 1 deletion identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,7 @@ func TestHandler(t *testing.T) {
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentities(context.Background(), toCreate...))

for _, perPage := range []int{10, 50, 100, 500} {
perPage := perPage
t.Run(fmt.Sprintf("perPage=%d", perPage), func(t *testing.T) {
t.Parallel()
body, res := getFull(t, ts, fmt.Sprintf("/identities?per_page=%d", perPage), http.StatusOK)
Expand Down Expand Up @@ -1524,7 +1525,7 @@ func TestHandler(t *testing.T) {
assert.Equal(t, count/perPage, pages)
})

t.Run("using üage pagination", func(t *testing.T) {
t.Run("using page pagination", func(t *testing.T) {
knownIDs := make(map[string]struct{})
var isLast bool
var pages int
Expand Down
171 changes: 171 additions & 0 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ package identity

import (
"context"
"encoding/json"
"reflect"
"sort"

"go.opentelemetry.io/otel/trace"

"github.com/ory/kratos/schema"
"github.com/ory/kratos/x/events"
"github.com/ory/x/sqlcon"

"github.com/ory/x/otelx"

Expand Down Expand Up @@ -41,6 +45,7 @@ type (
courier.Provider
ValidationProvider
ActiveCredentialsCounterStrategyProvider
x.LoggingProvider
}
ManagementProvider interface {
IdentityManager() *Manager
Expand Down Expand Up @@ -95,13 +100,179 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
}

if err := m.r.PrivilegedIdentityPool().CreateIdentity(ctx, i); err != nil {
if errors.Is(err, sqlcon.ErrUniqueViolation) {
return m.findExistingAuthMethod(ctx, err, i)
}
return err
}

trace.SpanFromContext(ctx).AddEvent(events.NewIdentityCreated(ctx, i.ID))
return nil
}

func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) {
if !m.r.Config().SelfServiceFlowRegistrationLoginHints(ctx) {
return &ErrDuplicateCredentials{error: e}
}
// First we try to find the conflict in the identifiers table. This is most likely to have a conflict.
var found *Identity
for ct, cred := range i.Credentials {
for _, id := range cred.Identifiers {
found, _, err = m.r.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, id)
if err != nil {
continue
}

// FindByCredentialsIdentifier does not expand identity credentials.
if err = m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, found, ExpandCredentials); err != nil {
return err
}
}
}

// If the conflict is not in the identifiers table, it is coming from the verifiable or recovery address.
var foundConflictAddress string
if found == nil {
for _, va := range i.VerifiableAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return err
}
}
}

// Last option: check the recovery address
if found == nil {
for _, va := range i.RecoveryAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindRecoveryAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return err
}
}
}

// Still not found? Return generic error.
if found == nil {
return &ErrDuplicateCredentials{error: e}
}

// We need to sort the credentials for the error message to be deterministic.
var creds []Credentials
for _, cred := range found.Credentials {
creds = append(creds, cred)
}
sort.Slice(creds, func(i, j int) bool {
return creds[i].Type < creds[j].Type
})

for _, cred := range creds {
if cred.Config == nil {
continue
}

// Basically, we only have password, oidc, and webauthn as first factor credentials.
// We don't care about second factor, because they don't help the user understand how to sign
// in to the first factor (obviously).
switch cred.Type {
case CredentialsTypePassword:
identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
}
return &ErrDuplicateCredentials{
error: e,
availableCredentials: []CredentialsType{cred.Type},
identifierHint: identifierHint,
}
case CredentialsTypeOIDC:
var cfg CredentialsOIDC
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

available := make([]string, 0, len(cfg.Providers))
for _, provider := range cfg.Providers {
available = append(available, provider.Provider)
}

return &ErrDuplicateCredentials{
error: e,
availableCredentials: []CredentialsType{cred.Type},
availableOIDCProviders: available,
identifierHint: foundConflictAddress,
}
case CredentialsTypeWebAuthn:
var cfg CredentialsWebAuthnConfig
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
}

for _, webauthn := range cfg.Credentials {
if webauthn.IsPasswordless {
return &ErrDuplicateCredentials{
error: e,
availableCredentials: []CredentialsType{cred.Type},
identifierHint: identifierHint,
}
}
}
}
}

// Still not found? Return generic error.
return &ErrDuplicateCredentials{error: e}
}

type ErrDuplicateCredentials struct {
error

availableCredentials []CredentialsType
availableOIDCProviders []string
identifierHint string
}

var _ schema.DuplicateCredentialsHinter = (*ErrDuplicateCredentials)(nil)

func (e ErrDuplicateCredentials) AvailableCredentials() []string {
res := make([]string, len(e.availableCredentials))
for k, v := range e.availableCredentials {
res[k] = string(v)
}
return res
}

func (e ErrDuplicateCredentials) AvailableOIDCProviders() []string {
return e.availableOIDCProviders
}

func (e ErrDuplicateCredentials) IdentifierHint() string {
return e.identifierHint
}
func (e ErrDuplicateCredentials) HasHints() bool {
return len(e.availableCredentials) > 0 || len(e.availableOIDCProviders) > 0 || len(e.identifierHint) > 0
}

func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, opts ...ManagerOption) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)
Expand Down
Loading

0 comments on commit 8b28469

Please sign in to comment.