Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: code review for code strategy (magic code login) #3456

Merged
merged 11 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ test-short:

.PHONY: test-coverage
test-coverage: .bin/go-acc .bin/goveralls
go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite
go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite,json1

.PHONY: test-coverage-next
test-coverage-next: .bin/go-acc .bin/goveralls
go test -short -failfast -timeout=20m -tags sqlite,json1 -cover ./... --args test.gocoverdir="$$PWD/coverage"
go tool covdata percent -i=coverage
go tool covdata textfmt -i=./coverage -o coverage.new.out

# Generates the SDK
.PHONY: sdk
Expand Down
1 change: 1 addition & 0 deletions coverage/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*
2 changes: 1 addition & 1 deletion driver/registry_default_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context,
}

if len(b) == initialHookCount {
// since we don't want merging hooks defined in a specific strategy and global hooks
// since we don't want merging hooks defined in a specific strategy and
// global hooks are added only if no strategy specific hooks are defined
for _, v := range m.getHooks(config.HookGlobal, m.Config().SelfServiceFlowRegistrationAfterHooks(ctx, config.HookGlobal)) {
if hook, ok := v.(registration.PostHookPostPersistExecutor); ok {
Expand Down
6 changes: 0 additions & 6 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,6 @@ type Credentials struct {
// Identifiers represents a list of unique identifiers this credential type matches.
Identifiers []string `json:"identifiers" db:"-"`

// IdentifierAddressType represents the type of the identifiers (e.g. email, phone).
// This is used to determine the correct courier to send messages to.
// The value is set by the code extension schema and is not persisted.
// only applicable on the login, registration with `code` method.
IdentifierAddressType CredentialsIdentifierAddressType `json:"-" db:"-"`

// Config contains the concrete credential payload. This might contain the bcrypt-hashed password, the email
// for passwordless authentication or access_token and refresh tokens from OpenID Connect flows.
Config sqlxx.JSONRawMessage `json:"config,omitempty" db:"config"`
Expand Down
2 changes: 1 addition & 1 deletion identity/credentials_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
CodeAddressTypePhone CodeAddressType = AddressTypePhone
)

// CredentialsCode represents a one time login/registraiton code
// CredentialsCode represents a one time login/registration code
//
// swagger:model identityCredentialsCode
type CredentialsCode struct {
Expand Down
3 changes: 1 addition & 2 deletions identity/extension_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func (r *SchemaExtensionCredentials) setIdentifier(ct CredentialsType, value int

r.v[ct] = stringslice.Unique(append(r.v[ct], strings.ToLower(fmt.Sprintf("%s", value))))
cred.Identifiers = r.v[ct]
cred.IdentifierAddressType = addressType
r.i.SetCredentials(ct, *cred)
}

Expand All @@ -64,7 +63,7 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch
return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
}

r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressType(AddressTypeEmail))
r.setIdentifier(CredentialsTypeCodeAuth, value, AddressTypeEmail)
// case f.AddCase(AddressTypePhone):
// if !jsonschema.Formats["tel"](value) {
// return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
Expand Down
15 changes: 15 additions & 0 deletions identity/extension_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ func TestSchemaExtensionCredentials(t *testing.T) {
},
ct: identity.CredentialsTypeWebAuthn,
},
{
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These did not have tests

doc: `{"email":"[email protected]"}`,
schema: "file://./stub/extension/credentials/code.schema.json",
expect: []string{"[email protected]"},
ct: identity.CredentialsTypeCodeAuth,
},
{
doc: `{"email":"[email protected]"}`,
schema: "file://./stub/extension/credentials/code.schema.json",
expect: []string{"[email protected]"},
existing: &identity.Credentials{
Identifiers: []string{"[email protected]"},
},
ct: identity.CredentialsTypeCodeAuth,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
c := jsonschema.NewCompiler()
Expand Down
20 changes: 20 additions & 0 deletions identity/stub/extension/credentials/code.schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email",
"ory.sh/kratos": {
"credentials": {
"password": {
"identifier": true
},
"code": {
"identifier": true,
"via": "email"
}
}
}
}
}
}
123 changes: 123 additions & 0 deletions persistence/sql/persister_code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package sql

import (
"context"
"crypto/subtle"
"fmt"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/x/sqlcon"
)

type oneTimeCodeProvider interface {
GetID() uuid.UUID
Validate() error
TableName(ctx context.Context) string
GetHMACCode() string
}

type codeOptions struct {
IdentityID *uuid.UUID
}

type codeOption func(o *codeOptions)

func withCheckIdentityID(id uuid.UUID) codeOption {
return func(o *codeOptions) {
o.IdentityID = &id
}
}

func useOneTimeCode[P any, U interface {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was the same for login, verification, recovery, and sign up, so I moved it into a common function. Uses some generic magic to make it work nicely. Tried it with duck typing only at first, but it caused issues and would have required reflection.

*P
oneTimeCodeProvider
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption) (U, error) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, Generics can not be avoided because we need to work with and modify the concrete underlying data type, and interfaces have certain limitations such as dealing with uninitialized values.

ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode")
defer span.End()

o := new(codeOptions)
for _, opt := range opts {
opt(o)
}

var target U
nid := p.NetworkID(ctx)
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
//#nosec G201 -- TableName is static
if err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec(); err != nil {
return err
}

var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
}
return err
}

// This check prevents parallel brute force attacks by checking the submit count inside this database
// transaction. If the flow has been submitted more than 5 times, the transaction is aborted (regardless of
// whether the code was correct or not) and we thus give no indication whether the supplied code was correct or
// not. For more explanation see [this comment](https://github.com/ory/kratos/pull/2645#discussion_r984732899).
if submitCount > 5 {
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}

var codes []U
codesQuery := tx.Where(fmt.Sprintf("nid = ? AND %s = ?", foreignKeyName), nid, flowID)
if o.IdentityID != nil {
codesQuery = codesQuery.Where("identity_id = ?", *o.IdentityID)
}

if err := sqlcon.HandleError(codesQuery.All(&codes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction and reset the submit count.
return nil
}

return err
}

secrets:
for _, secret := range p.r.Config().SecretsSession(ctx) {
suppliedCode := []byte(p.hmacValueWithSecret(ctx, userProvidedCode, secret))
for i := range codes {
c := codes[i]
if subtle.ConstantTimeCompare([]byte(c.GetHMACCode()), suppliedCode) == 0 {
// Not the supplied code
continue
}
target = c
break secrets
}
}

if target.Validate() != nil {
// Return no error, as that would roll back the transaction
return nil
}

//#nosec G201 -- TableName is static
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", target.TableName(ctx)), time.Now().UTC(), target.GetID(), nid).Exec()
}); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := target.Validate(); err != nil {
return nil, err
}

return target, nil
}
125 changes: 0 additions & 125 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@ package sql

import (
"context"
"crypto/subtle"
"fmt"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/pkg/errors"

"github.com/gofrs/uuid"

"github.com/ory/x/sqlcon"

"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/code"
)

var _ login.FlowPersister = new(Persister)
Expand Down Expand Up @@ -88,123 +83,3 @@ func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.
}
return nil
}

func (p *Persister) CreateLoginCode(ctx context.Context, codeParams *code.CreateLoginCodeParams) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginCode")
defer span.End()

now := time.Now().UTC()
loginCode := &code.LoginCode{
IdentityID: codeParams.IdentityID,
Address: codeParams.Address,
AddressType: codeParams.AddressType,
CodeHMAC: p.hmacValue(ctx, codeParams.RawCode),
IssuedAt: now,
ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)),
FlowID: codeParams.FlowID,
NID: p.NetworkID(ctx),
ID: uuid.Nil,
}

if err := p.GetConnection(ctx).Create(loginCode); err != nil {
return nil, sqlcon.HandleError(err)
}
return loginCode, nil
}

func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identityID uuid.UUID, codeVal string) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
defer span.End()

var loginCode *code.LoginCode

nid := p.NetworkID(ctx)
flowTableName := new(login.Flow).TableName(ctx)

if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec()); err != nil {
return err
}

var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
}
return err
}

if submitCount > 5 {
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}

var loginCodes []code.LoginCode
if err = sqlcon.HandleError(tx.Where("nid = ? AND selfservice_login_flow_id = ? AND identity_id = ?", nid, flowID, identityID).All(&loginCodes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return err
}
return nil
}

secrets:
for _, secret := range p.r.Config().SecretsSession(ctx) {
suppliedCode := []byte(p.hmacValueWithSecret(ctx, codeVal, secret))
for i := range loginCodes {
code := loginCodes[i]
if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 {
// Not the supplied code
continue
}
loginCode = &code
break secrets
}
}

if loginCode == nil || !loginCode.IsValid() {
// Return no error, as that would roll back the transaction
return nil
}

//#nosec G201 -- TableName is static
return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", loginCode.TableName(ctx)), time.Now().UTC(), loginCode.ID, nid).Exec())
})); err != nil {
return nil, err
}

if loginCode == nil {
return nil, errors.WithStack(code.ErrCodeNotFound)
}

if loginCode.IsExpired() {
return nil, errors.WithStack(flow.NewFlowExpiredError(loginCode.ExpiresAt))
}

if loginCode.WasUsed() {
return nil, errors.WithStack(code.ErrCodeAlreadyUsed)
}

return loginCode, nil
}

func (p *Persister) DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginCodesOfFlow")
defer span.End()

//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_login_flow_id = ? AND nid = ?", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).Exec()
}

func (p *Persister) GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedLoginCode")
defer span.End()

var loginCode code.LoginCode
if err := p.Connection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s WHERE selfservice_login_flow_id = ? AND nid = ? AND used_at IS NOT NULL", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).First(&loginCode); err != nil {
return nil, sqlcon.HandleError(err)
}
return &loginCode, nil
}
Loading
Loading