Skip to content

Commit

Permalink
fix: reduce db lookups in whoami for aal check (#3372)
Browse files Browse the repository at this point in the history
Significantly improves performance by reducing the amount of queries we need to do when checking for the different AAL levels.
  • Loading branch information
aeneasr committed Jul 13, 2023
1 parent 86ad5e1 commit d814a48
Show file tree
Hide file tree
Showing 19 changed files with 371 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ migrations-sync: .bin/ory

.PHONY: test-update-snapshots
test-update-snapshots:
UPDATE_SNAPSHOTS=true go test -p 4 -tags sqlite -short ./...
UPDATE_SNAPSHOTS=true go test -tags sqlite,json1,refresh -short ./...

.PHONY: post-release
post-release: .bin/yq
Expand Down
39 changes: 39 additions & 0 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package identity

import (
"context"
"database/sql"
"reflect"
"time"

Expand Down Expand Up @@ -33,6 +34,44 @@ const (
AuthenticatorAssuranceLevel2 AuthenticatorAssuranceLevel = "aal2"
)

type NullableAuthenticatorAssuranceLevel struct {
sql.NullString
}

// NewNullableAuthenticatorAssuranceLevel returns a new NullableAuthenticatorAssuranceLevel
func NewNullableAuthenticatorAssuranceLevel(aal AuthenticatorAssuranceLevel) NullableAuthenticatorAssuranceLevel {
switch aal {
case NoAuthenticatorAssuranceLevel:
fallthrough
case AuthenticatorAssuranceLevel1:
fallthrough
case AuthenticatorAssuranceLevel2:
return NullableAuthenticatorAssuranceLevel{sql.NullString{
String: string(aal),
Valid: true,
}}
default:
return NullableAuthenticatorAssuranceLevel{sql.NullString{}}
}
}

// ToAAL returns the AuthenticatorAssuranceLevel value of the given NullableAuthenticatorAssuranceLevel.
func (n NullableAuthenticatorAssuranceLevel) ToAAL() (AuthenticatorAssuranceLevel, bool) {
if !n.Valid {
return "", false
}
switch n.String {
case string(NoAuthenticatorAssuranceLevel):
return NoAuthenticatorAssuranceLevel, true
case string(AuthenticatorAssuranceLevel1):
return AuthenticatorAssuranceLevel1, true
case string(AuthenticatorAssuranceLevel2):
return AuthenticatorAssuranceLevel2, true
default:
return "", false
}
}

// CredentialsType represents several different credential types, like password credentials, passwordless credentials,
// and so on.
//
Expand Down
25 changes: 25 additions & 0 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type Identity struct {
// Credentials represents all credentials that can be used for authenticating this identity.
Credentials map[CredentialsType]Credentials `json:"credentials,omitempty" faker:"-" db:"-"`

// AvailableAAL defines the maximum available AAL for this identity. If the user has only a password
// configured, the AAL will be 1. If the user has a password and a TOTP configured, the AAL will be 2.
AvailableAAL NullableAuthenticatorAssuranceLevel `json:"-" faker:"-" db:"available_aal"`

// // IdentifierCredentials contains the access and refresh token for oidc identifier
// IdentifierCredentials []IdentifierCredential `json:"identifier_credentials,omitempty" faker:"-" db:"-"`

Expand Down Expand Up @@ -317,6 +321,27 @@ func (i *Identity) UnmarshalJSON(b []byte) error {
return err
}

func (i *Identity) SetAvailableAAL(ctx context.Context, m *Manager) (err error) {
i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(NoAuthenticatorAssuranceLevel)
if c, err := m.CountActiveFirstFactorCredentials(ctx, i); err != nil {
return err
} else if c == 0 {
// No first factor set up - AAL is 0
return nil
}

i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel1)
if c, err := m.CountActiveMultiFactorCredentials(ctx, i); err != nil {
return err
} else if c == 0 {
// No second factor set up - AAL is 1
return nil
}

i.AvailableAAL = NewNullableAuthenticatorAssuranceLevel(AuthenticatorAssuranceLevel2)
return nil
}

type WithAdminMetadataInJSON Identity

func (i WithAdminMetadataInJSON) MarshalJSON() ([]byte, error) {
Expand Down
12 changes: 12 additions & 0 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
return err
}

if err := i.SetAvailableAAL(ctx, m); err != nil {
return err
}

if err := m.r.PrivilegedIdentityPool().CreateIdentity(ctx, i); err != nil {
return err
}
Expand All @@ -107,6 +111,10 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity,
i.SchemaID = m.r.Config().DefaultIdentityTraitsSchemaID(ctx)
}

if err := i.SetAvailableAAL(ctx, m); err != nil {
return err
}

o := newManagerOptions(opts)
if err := m.ValidateIdentity(ctx, i, o); err != nil {
return err
Expand Down Expand Up @@ -164,6 +172,10 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager
return err
}

if err := updated.SetAvailableAAL(ctx, m); err != nil {
return err
}

return m.r.PrivilegedIdentityPool().UpdateIdentity(ctx, updated)
}

Expand Down
135 changes: 133 additions & 2 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/gofrs/uuid"

"github.com/ory/x/sqlxx"

"github.com/ory/kratos/internal/testhelpers"
Expand Down Expand Up @@ -65,10 +67,82 @@ func TestManager(t *testing.T) {

t.Run("method=Create", func(t *testing.T) {
t.Run("case=should create identity and track extension fields", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits("[email protected]", "")
original.Traits = newTraits(email, "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
checkExtensionFieldsForIdentities(t, "[email protected]", original)
checkExtensionFieldsForIdentities(t, email, original)
got, ok := original.AvailableAAL.ToAAL()
require.True(t, ok)
assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got)
})

t.Run("case=correctly set AAL", func(t *testing.T) {
t.Run("case=should set AAL to 0 if no credentials are available", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
got, ok := original.AvailableAAL.ToAAL()
require.True(t, ok)
assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got)
})

t.Run("case=should set AAL to 1 if password is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {
Type: identity.CredentialsTypePassword,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`),
},
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
got, ok := original.AvailableAAL.ToAAL()
require.True(t, ok)
assert.Equal(t, identity.AuthenticatorAssuranceLevel1, got)
})

t.Run("case=should set AAL to 2 if password and TOTP is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {
Type: identity.CredentialsTypePassword,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`),
},
identity.CredentialsTypeTOTP: {
Type: identity.CredentialsTypeTOTP,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`),
},
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
got, ok := original.AvailableAAL.ToAAL()
require.True(t, ok)
assert.Equal(t, identity.AuthenticatorAssuranceLevel2, got)
})

t.Run("case=should set AAL to 0 if only TOTP is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypeTOTP: {
Type: identity.CredentialsTypeTOTP,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`),
},
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
got, ok := original.AvailableAAL.ToAAL()
require.True(t, ok)
assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, got)
})
})

t.Run("case=should expose validation errors with option", func(t *testing.T) {
Expand Down Expand Up @@ -100,6 +174,63 @@ func TestManager(t *testing.T) {
checkExtensionFieldsForIdentities(t, "[email protected]", original)
})

t.Run("case=should set AAL to 1 if password is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {
Type: identity.CredentialsTypePassword,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`),
},
}
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))
assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String)
})

t.Run("case=should set AAL to 2 if password and TOTP is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {
Type: identity.CredentialsTypePassword,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`),
},
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String)
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))
assert.EqualValues(t, identity.AuthenticatorAssuranceLevel1, original.AvailableAAL.String, "Updating without changes should not change AAL")
original.Credentials[identity.CredentialsTypeTOTP] = identity.Credentials{
Type: identity.CredentialsTypeTOTP,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`),
}
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))
assert.EqualValues(t, identity.AuthenticatorAssuranceLevel2, original.AvailableAAL.String)
})

t.Run("case=should set AAL to 0 if only TOTP is set", func(t *testing.T) {
email := uuid.Must(uuid.NewV4()).String() + "@ory.sh"
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits(email, "")
require.NoError(t, reg.IdentityManager().Create(context.Background(), original))
original.Credentials = map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypeTOTP: {
Type: identity.CredentialsTypeTOTP,
Identifiers: []string{email},
Config: sqlxx.JSONRawMessage(`{"totp_url":"otpauth://totp/test"}`),
},
}
require.NoError(t, reg.IdentityManager().Update(context.Background(), original, identity.ManagerAllowWriteProtectedTraits))
assert.True(t, original.AvailableAAL.Valid)
assert.EqualValues(t, identity.NoAuthenticatorAssuranceLevel, original.AvailableAAL.String)
})

t.Run("case=should not update protected traits without option", func(t *testing.T) {
original := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
original.Traits = newTraits("[email protected]", "")
Expand Down
12 changes: 12 additions & 0 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
})
})

t.Run("case=create with null AAL", func(t *testing.T) {
expected := passwordIdentity("", "id-"+uuid.Must(uuid.NewV4()).String())
expected.AvailableAAL.Valid = false
require.NoError(t, p.CreateIdentity(ctx, expected))
createdIDs = append(createdIDs, expected.ID)

actual, err := p.GetIdentity(ctx, expected.ID, identity.ExpandDefault)
require.NoError(t, err)

assert.False(t, actual.AvailableAAL.Valid)
})

t.Run("suite=create multiple identities", func(t *testing.T) {
t.Run("create multiple identities", func(t *testing.T) {
identities := make([]*identity.Identity, 100)
Expand Down
3 changes: 2 additions & 1 deletion internal/testhelpers/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func NewSessionClient(t *testing.T, u string) *http.Client {
func maybePersistSession(t *testing.T, reg *driver.RegistryDefault, sess *session.Session) {
id, err := reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), sess.Identity.ID)
if err != nil {
require.NoError(t, sess.Identity.SetAvailableAAL(context.Background(), reg.IdentityManager()))
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), sess.Identity))
id, err = reg.PrivilegedIdentityPool().GetIdentityConfidential(context.Background(), sess.Identity.ID)
require.NoError(t, err)
Expand Down Expand Up @@ -156,7 +157,7 @@ func NewHTTPClientWithArbitrarySessionToken(t *testing.T, reg *driver.RegistryDe
func NewHTTPClientWithArbitrarySessionCookie(t *testing.T, reg *driver.RegistryDefault) *http.Client {
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
s, err := session.NewActiveSession(req,
&identity.Identity{ID: x.NewUUID(), State: identity.StateActive},
&identity.Identity{ID: x.NewUUID(), State: identity.StateActive, Traits: []byte("{}")},
NewSessionLifespanProvider(time.Hour),
time.Now(),
identity.CredentialsTypePassword,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"TableName": "\"identities\"",
"ColumnsDecl": "\"created_at\", \"id\", \"metadata_admin\", \"metadata_public\", \"nid\", \"schema_id\", \"state\", \"state_changed_at\", \"traits\", \"updated_at\"",
"ColumnsDecl": "\"available_aal\", \"created_at\", \"id\", \"metadata_admin\", \"metadata_public\", \"nid\", \"schema_id\", \"state\", \"state_changed_at\", \"traits\", \"updated_at\"",
"Columns": [
"available_aal",
"created_at",
"id",
"metadata_admin",
Expand All @@ -13,5 +14,5 @@
"traits",
"updated_at"
],
"Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
"Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"id": "0149ce5f-76a8-4efe-b2e3-431b8c6cceb6",
"schema_id": "default",
"schema_url": "https://www.ory.sh/schemas/ZGVmYXVsdA",
"state": "active",
"traits": {
"email": "[email protected]"
},
"metadata_public": {
"foo": "bar"
},
"metadata_admin": {
"baz": "bar"
},
"created_at": "2013-10-07T08:23:19Z",
"updated_at": "2013-10-07T08:23:19Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"id": "0149ce5f-76a8-4efe-b2e3-431b8c6cceb7",
"schema_id": "default",
"schema_url": "https://www.ory.sh/schemas/ZGVmYXVsdA",
"state": "active",
"traits": {
"email": "[email protected]"
},
"metadata_public": {
"foo": "bar"
},
"metadata_admin": {
"baz": "bar"
},
"created_at": "2013-10-07T08:23:19Z",
"updated_at": "2013-10-07T08:23:19Z"
}
11 changes: 11 additions & 0 deletions persistence/sql/migratest/testdata/20230706000000_testdata.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
INSERT INTO identities (id, nid, schema_id, traits, created_at, updated_at, metadata_public, metadata_admin,
available_aal)
VALUES ('0149ce5f-76a8-4efe-b2e3-431b8c6cceb6', '884f556e-eb3a-4b9f-bee3-11345642c6c0', 'default',
'{"email":"[email protected]"}', '2013-10-07 08:23:19', '2013-10-07 08:23:19', '{"foo":"bar"}', '{"baz":"bar"}',
'aal1');

INSERT INTO identities (id, nid, schema_id, traits, created_at, updated_at, metadata_public, metadata_admin,
available_aal)
VALUES ('0149ce5f-76a8-4efe-b2e3-431b8c6cceb7', '884f556e-eb3a-4b9f-bee3-11345642c6c0', 'default',
'{"email":"[email protected]"}', '2013-10-07 08:23:19', '2013-10-07 08:23:19', '{"foo":"bar"}', '{"baz":"bar"}',
NULL);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE identities DROP COLUMN available_aal;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE identities ADD COLUMN available_aal VARCHAR(4) NULL;
Loading

0 comments on commit d814a48

Please sign in to comment.