Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: change primary keys in tables for identity-associated data #3777

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,39 @@ import (
"context"
"io/fs"

"github.com/ory/kratos/selfservice/sessiontokenexchange"
"github.com/ory/x/contextx"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/otelx"
prometheus "github.com/ory/x/prometheusx"

"github.com/gorilla/sessions"
"github.com/pkg/errors"

"github.com/ory/nosurf"

"github.com/ory/x/logrusx"

"github.com/ory/kratos/cipher"
"github.com/ory/kratos/continuity"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/hash"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence"
"github.com/ory/kratos/schema"
"github.com/ory/kratos/selfservice/errorx"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/logout"
"github.com/ory/kratos/selfservice/flow/recovery"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/selfservice/flow/verification"
"github.com/ory/kratos/selfservice/sessiontokenexchange"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/selfservice/strategy/link"

"github.com/ory/x/healthx"

"github.com/ory/kratos/persistence"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/logout"
"github.com/ory/kratos/selfservice/flow/registration"

"github.com/ory/kratos/x"

"github.com/ory/x/dbal"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/errorx"
password2 "github.com/ory/kratos/selfservice/strategy/password"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/nosurf"
"github.com/ory/x/contextx"
"github.com/ory/x/dbal"
"github.com/ory/x/healthx"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/logrusx"
"github.com/ory/x/otelx"
"github.com/ory/x/popx"
prometheus "github.com/ory/x/prometheusx"
)

type Registry interface {
Expand Down Expand Up @@ -85,6 +79,8 @@ type Registry interface {
continuity.ManagementProvider
continuity.PersistenceProvider

cipher.Provider

courier.Provider

persistence.Provider
Expand Down Expand Up @@ -186,10 +182,12 @@ type options struct {
replaceIdentitySchemaProvider func(Registry) schema.IdentitySchemaProvider
inspect func(Registry) error
extraMigrations []fs.FS
replacementStrategies []NewStrategy
extraHooks map[string]func(config.SelfServiceHook) any
disableMigrationLogging bool
jsonnetPool jsonnetsecure.Pool
extraGoMigrations popx.Migrations

replacementStrategies []NewStrategy
extraHooks map[string]func(config.SelfServiceHook) any
disableMigrationLogging bool
jsonnetPool jsonnetsecure.Pool
}

type RegistryOption func(*options)
Expand Down Expand Up @@ -251,6 +249,12 @@ func WithExtraMigrations(m ...fs.FS) RegistryOption {
}
}

func WithExtraGoMigrations(m ...popx.Migration) RegistryOption {
return func(o *options) {
o.extraGoMigrations = append(o.extraGoMigrations, m...)
}
}

func WithDisabledMigrationLogging() RegistryOption {
return func(o *options) {
o.disableMigrationLogging = true
Expand Down
5 changes: 4 additions & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,10 @@ func (m *RegistryDefault) Init(ctx context.Context, ctxer contextx.Contextualize
m.Logger().WithError(err).Warnf("Unable to open database, retrying.")
return errors.WithStack(err)
}
p, err := sql.NewPersister(ctx, m, c, sql.WithExtraMigrations(o.extraMigrations...), sql.WithDisabledLogging(o.disableMigrationLogging))
p, err := sql.NewPersister(ctx, m, c,
sql.WithExtraMigrations(o.extraMigrations...),
sql.WithExtraGoMigrations(o.extraGoMigrations...),
sql.WithDisabledLogging(o.disableMigrationLogging))
if err != nil {
m.Logger().WithError(err).Warnf("Unable to initialize persister, retrying.")
return err
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ require (
github.com/jmoiron/sqlx v1.4.0
github.com/julienschmidt/httprouter v1.3.0
github.com/knadh/koanf/parsers/json v0.1.0
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7
github.com/laher/mergefs v0.1.2-0.20230223191438-d16611b2f4e7 // indirect
github.com/lestrrat-go/jwx/v2 v2.1.1
github.com/luna-duclos/instrumentedsql v1.1.3
github.com/mailhog/MailHog v1.0.1
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,6 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/markbates/pkger v0.17.1 h1:/MKEtWqtc0mZvu9OinB9UzVN9iYCwLWuyUv4Bw+PCno=
github.com/markbates/pkger v0.17.1/go.mod h1:0JoVlrol20BSywW79rN3kdFFsE5xYM+rSCQDXbLhiuI=
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
Expand Down
6 changes: 6 additions & 0 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,17 @@ func (c Credentials) GetID() uuid.UUID {
return c.ID
}

func (c Credentials) GetNID() uuid.UUID {
return c.NID
}

type (
// swagger:ignore
CredentialIdentifier struct {
ID uuid.UUID `db:"id"`
Identifier string `db:"identifier"`
// Identity is a helper struct field for gobuffalo.pop.
IdentityID uuid.UUID `json:"-" db:"identity_id"`
// IdentityCredentialsID is a helper struct field for gobuffalo.pop.
IdentityCredentialsID uuid.UUID `json:"-" db:"identity_credential_id"`
// IdentityCredentialsTypeID is a helper struct field for gobuffalo.pop.
Expand Down
4 changes: 2 additions & 2 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -1347,8 +1347,8 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credentials (id, identity_id, nid, identity_credential_type_id, created_at, updated_at, config) VALUES (?, ?, ?, ?, ?, ?, '{}')", cid2, iid, nid2, m[0].ID, time.Now(), time.Now()).Exec())

ici1, ici2 := x.NewUUID(), x.NewUUID()
require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?)", ici1, cid1, nid1, "nid1", time.Now(), time.Now(), m[0].ID).Exec())
require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?)", ici2, cid2, nid2, "nid2", time.Now(), time.Now(), m[0].ID).Exec())
require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ici1, iid, cid1, nid1, "nid1", time.Now(), time.Now(), m[0].ID).Exec())
require.NoError(t, p.GetConnection(ctx).RawQuery("INSERT INTO identity_credential_identifiers (id, identity_id, identity_credential_id, nid, identifier, created_at, updated_at, identity_credential_type_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ici2, iid, cid2, nid2, "nid2", time.Now(), time.Now(), m[0].ID).Exec())

_, err := p.GetIdentity(ctx, nid1, identity.ExpandNothing)
require.ErrorIs(t, err, sqlcon.ErrNoRows)
Expand Down
8 changes: 5 additions & 3 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn

identifiers = append(identifiers, &identity.CredentialIdentifier{
Identifier: identifier,
IdentityID: cred.IdentityID,
IdentityCredentialsID: cred.ID,
IdentityCredentialsTypeID: ct.ID,
NID: p.NetworkID(ctx),
Expand Down Expand Up @@ -633,7 +634,6 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
for _, k := range paritalErr.Failed {
failedIdentityIDs[k.IdentityID] = struct{}{}
}

} else if paritalErr := new(batch.PartialConflictError[identity.CredentialIdentifier]); errors.As(err, &paritalErr) {
for _, k := range paritalErr.Failed {
credID := k.IdentityCredentialsID
Expand Down Expand Up @@ -787,7 +787,7 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
ici := "identity_credential_identifiers"
switch con.Dialect.Name() {
case "cockroach":
ici += "@identity_credential_identifiers_nid_identity_credential_id_idx"
ici += "@primary"
case "sqlite3":
ici += " INDEXED BY identity_credential_identifiers_nid_identity_credential_id_idx"
case "mysql":
Expand All @@ -811,7 +811,9 @@ func QueryForCredentials(con *pop.Connection, where ...Where) (map[uuid.UUID](ma
"(identity_credentials.identity_credential_type_id = ict.id)",
).LeftJoin(
ici,
"identity_credential_identifiers.identity_credential_id = identity_credentials.id AND identity_credential_identifiers.nid = identity_credentials.nid",
`identity_credential_identifiers.identity_id = identity_credentials.identity_id
AND identity_credential_identifiers.identity_credential_id = identity_credentials.id
AND identity_credential_identifiers.nid = identity_credentials.nid`,
)
for _, w := range where {
q = q.Where("("+w.Condition+")", w.Args...)
Expand Down
30 changes: 15 additions & 15 deletions persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,21 @@ import (
"encoding/json"
"os"
"path/filepath"
"slices"
"sync"
"testing"
"time"

"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/servicelocatorx"

"github.com/ory/kratos/identity"

"github.com/bradleyjkemp/cupaloy/v2"
"github.com/stretchr/testify/assert"

"github.com/ory/x/dbal"

"github.com/ory/kratos/x/xsql"

"github.com/ory/x/migratest"

"github.com/gobuffalo/pop/v6"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence/sql/migrations/gomigrations"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/recovery"
"github.com/ory/kratos/selfservice/flow/registration"
Expand All @@ -41,9 +32,14 @@ import (
"github.com/ory/kratos/selfservice/strategy/link"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/kratos/x/xsql"
"github.com/ory/x/configx"
"github.com/ory/x/dbal"
"github.com/ory/x/logrusx"
"github.com/ory/x/migratest"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/popx"
"github.com/ory/x/servicelocatorx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlcon/dockertest"
)
Expand Down Expand Up @@ -87,15 +83,15 @@ func TestMigrations_Postgres(t *testing.T) {
t.Skip("skipping testing in short mode")
}
t.Parallel()
testDatabase(t, "postgres", dockertest.ConnectPop(t, dockertest.RunTestPostgreSQLWithVersion(t, "11.8")))
testDatabase(t, "postgres", dockertest.ConnectPop(t, dockertest.RunTestPostgreSQLWithVersion(t, "16")))
}

func TestMigrations_Mysql(t *testing.T) {
if testing.Short() {
t.Skip("skipping testing in short mode")
}
t.Parallel()
testDatabase(t, "mysql", dockertest.ConnectPop(t, dockertest.RunTestMySQLWithVersion(t, "8.0.34")))
testDatabase(t, "mysql", dockertest.ConnectPop(t, dockertest.RunTestMySQLWithVersion(t, "8.0")))
}

func TestMigrations_Cockroach(t *testing.T) {
Expand Down Expand Up @@ -134,6 +130,10 @@ func testDatabase(t *testing.T, db string, c *pop.Connection) {
os.DirFS("../migrations/sql"),
popx.NewMigrator(c, l, nil, 1*time.Minute),
popx.WithTestdata(t, os.DirFS("./testdata")),
popx.WithGoMigrations(slices.Concat(
gomigrations.IdentityPrimaryKeysStep1,
gomigrations.IdentityPrimaryKeysStep2,
)),
)
require.NoError(t, err)
tm.DumpMigrations = true
Expand Down
Loading
Loading