Skip to content

Commit

Permalink
chore: code review for code strategy (magic code login) (#3456)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 29, 2023
1 parent f1a0fae commit d8a5c2f
Show file tree
Hide file tree
Showing 43 changed files with 1,252 additions and 1,106 deletions.
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ test-short:

.PHONY: test-coverage
test-coverage: .bin/go-acc .bin/goveralls
go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite
go-acc -o coverage.out ./... -- -v -failfast -timeout=20m -tags sqlite,json1

.PHONY: test-coverage-next
test-coverage-next: .bin/go-acc .bin/goveralls
go test -short -failfast -timeout=20m -tags sqlite,json1 -cover ./... --args test.gocoverdir="$$PWD/coverage"
go tool covdata percent -i=coverage
go tool covdata textfmt -i=./coverage -o coverage.new.out

# Generates the SDK
.PHONY: sdk
Expand Down
1 change: 1 addition & 0 deletions coverage/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*
2 changes: 1 addition & 1 deletion driver/registry_default_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (m *RegistryDefault) PostRegistrationPostPersistHooks(ctx context.Context,
}

if len(b) == initialHookCount {
// since we don't want merging hooks defined in a specific strategy and global hooks
// since we don't want merging hooks defined in a specific strategy and
// global hooks are added only if no strategy specific hooks are defined
for _, v := range m.getHooks(config.HookGlobal, m.Config().SelfServiceFlowRegistrationAfterHooks(ctx, config.HookGlobal)) {
if hook, ok := v.(registration.PostHookPostPersistExecutor); ok {
Expand Down
6 changes: 0 additions & 6 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,6 @@ type Credentials struct {
// Identifiers represents a list of unique identifiers this credential type matches.
Identifiers []string `json:"identifiers" db:"-"`

// IdentifierAddressType represents the type of the identifiers (e.g. email, phone).
// This is used to determine the correct courier to send messages to.
// The value is set by the code extension schema and is not persisted.
// only applicable on the login, registration with `code` method.
IdentifierAddressType CredentialsIdentifierAddressType `json:"-" db:"-"`

// Config contains the concrete credential payload. This might contain the bcrypt-hashed password, the email
// for passwordless authentication or access_token and refresh tokens from OpenID Connect flows.
Config sqlxx.JSONRawMessage `json:"config,omitempty" db:"config"`
Expand Down
2 changes: 1 addition & 1 deletion identity/credentials_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
CodeAddressTypePhone CodeAddressType = AddressTypePhone
)

// CredentialsCode represents a one time login/registraiton code
// CredentialsCode represents a one time login/registration code
//
// swagger:model identityCredentialsCode
type CredentialsCode struct {
Expand Down
3 changes: 1 addition & 2 deletions identity/extension_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func (r *SchemaExtensionCredentials) setIdentifier(ct CredentialsType, value int

r.v[ct] = stringslice.Unique(append(r.v[ct], strings.ToLower(fmt.Sprintf("%s", value))))
cred.Identifiers = r.v[ct]
cred.IdentifierAddressType = addressType
r.i.SetCredentials(ct, *cred)
}

Expand All @@ -64,7 +63,7 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch
return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
}

r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressType(AddressTypeEmail))
r.setIdentifier(CredentialsTypeCodeAuth, value, AddressTypeEmail)
// case f.AddCase(AddressTypePhone):
// if !jsonschema.Formats["tel"](value) {
// return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
Expand Down
15 changes: 15 additions & 0 deletions identity/extension_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ func TestSchemaExtensionCredentials(t *testing.T) {
},
ct: identity.CredentialsTypeWebAuthn,
},
{
doc: `{"email":"[email protected]"}`,
schema: "file://./stub/extension/credentials/code.schema.json",
expect: []string{"[email protected]"},
ct: identity.CredentialsTypeCodeAuth,
},
{
doc: `{"email":"[email protected]"}`,
schema: "file://./stub/extension/credentials/code.schema.json",
expect: []string{"[email protected]"},
existing: &identity.Credentials{
Identifiers: []string{"[email protected]"},
},
ct: identity.CredentialsTypeCodeAuth,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
c := jsonschema.NewCompiler()
Expand Down
20 changes: 20 additions & 0 deletions identity/stub/extension/credentials/code.schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email",
"ory.sh/kratos": {
"credentials": {
"password": {
"identifier": true
},
"code": {
"identifier": true,
"via": "email"
}
}
}
}
}
}
123 changes: 123 additions & 0 deletions persistence/sql/persister_code.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package sql

import (
"context"
"crypto/subtle"
"fmt"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/x/sqlcon"
)

type oneTimeCodeProvider interface {
GetID() uuid.UUID
Validate() error
TableName(ctx context.Context) string
GetHMACCode() string
}

type codeOptions struct {
IdentityID *uuid.UUID
}

type codeOption func(o *codeOptions)

func withCheckIdentityID(id uuid.UUID) codeOption {
return func(o *codeOptions) {
o.IdentityID = &id
}
}

func useOneTimeCode[P any, U interface {
*P
oneTimeCodeProvider
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string, opts ...codeOption) (U, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode")
defer span.End()

o := new(codeOptions)
for _, opt := range opts {
opt(o)
}

var target U
nid := p.NetworkID(ctx)
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
//#nosec G201 -- TableName is static
if err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec(); err != nil {
return err
}

var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
}
return err
}

// This check prevents parallel brute force attacks by checking the submit count inside this database
// transaction. If the flow has been submitted more than 5 times, the transaction is aborted (regardless of
// whether the code was correct or not) and we thus give no indication whether the supplied code was correct or
// not. For more explanation see [this comment](https://github.com/ory/kratos/pull/2645#discussion_r984732899).
if submitCount > 5 {
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}

var codes []U
codesQuery := tx.Where(fmt.Sprintf("nid = ? AND %s = ?", foreignKeyName), nid, flowID)
if o.IdentityID != nil {
codesQuery = codesQuery.Where("identity_id = ?", *o.IdentityID)
}

if err := sqlcon.HandleError(codesQuery.All(&codes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction and reset the submit count.
return nil
}

return err
}

secrets:
for _, secret := range p.r.Config().SecretsSession(ctx) {
suppliedCode := []byte(p.hmacValueWithSecret(ctx, userProvidedCode, secret))
for i := range codes {
c := codes[i]
if subtle.ConstantTimeCompare([]byte(c.GetHMACCode()), suppliedCode) == 0 {
// Not the supplied code
continue
}
target = c
break secrets
}
}

if target.Validate() != nil {
// Return no error, as that would roll back the transaction
return nil
}

//#nosec G201 -- TableName is static
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", target.TableName(ctx)), time.Now().UTC(), target.GetID(), nid).Exec()
}); err != nil {
return nil, sqlcon.HandleError(err)
}

if err := target.Validate(); err != nil {
return nil, err
}

return target, nil
}
125 changes: 0 additions & 125 deletions persistence/sql/persister_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@ package sql

import (
"context"
"crypto/subtle"
"fmt"
"time"

"github.com/gobuffalo/pop/v6"
"github.com/pkg/errors"

"github.com/gofrs/uuid"

"github.com/ory/x/sqlcon"

"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/code"
)

var _ login.FlowPersister = new(Persister)
Expand Down Expand Up @@ -88,123 +83,3 @@ func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.
}
return nil
}

func (p *Persister) CreateLoginCode(ctx context.Context, codeParams *code.CreateLoginCodeParams) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginCode")
defer span.End()

now := time.Now().UTC()
loginCode := &code.LoginCode{
IdentityID: codeParams.IdentityID,
Address: codeParams.Address,
AddressType: codeParams.AddressType,
CodeHMAC: p.hmacValue(ctx, codeParams.RawCode),
IssuedAt: now,
ExpiresAt: now.UTC().Add(p.r.Config().SelfServiceCodeMethodLifespan(ctx)),
FlowID: codeParams.FlowID,
NID: p.NetworkID(ctx),
ID: uuid.Nil,
}

if err := p.GetConnection(ctx).Create(loginCode); err != nil {
return nil, sqlcon.HandleError(err)
}
return loginCode, nil
}

func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identityID uuid.UUID, codeVal string) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
defer span.End()

var loginCode *code.LoginCode

nid := p.NetworkID(ctx)
flowTableName := new(login.Flow).TableName(ctx)

if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET submit_count = submit_count + 1 WHERE id = ? AND nid = ?", flowTableName), flowID, nid).Exec()); err != nil {
return err
}

var submitCount int
// Because MySQL does not support "RETURNING" clauses, but we need the updated `submit_count` later on.
//#nosec G201 -- TableName is static
if err := sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("SELECT submit_count FROM %s WHERE id = ? AND nid = ?", flowTableName), flowID, nid).First(&submitCount)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction
return nil
}
return err
}

if submitCount > 5 {
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}

var loginCodes []code.LoginCode
if err = sqlcon.HandleError(tx.Where("nid = ? AND selfservice_login_flow_id = ? AND identity_id = ?", nid, flowID, identityID).All(&loginCodes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return err
}
return nil
}

secrets:
for _, secret := range p.r.Config().SecretsSession(ctx) {
suppliedCode := []byte(p.hmacValueWithSecret(ctx, codeVal, secret))
for i := range loginCodes {
code := loginCodes[i]
if subtle.ConstantTimeCompare([]byte(code.CodeHMAC), suppliedCode) == 0 {
// Not the supplied code
continue
}
loginCode = &code
break secrets
}
}

if loginCode == nil || !loginCode.IsValid() {
// Return no error, as that would roll back the transaction
return nil
}

//#nosec G201 -- TableName is static
return sqlcon.HandleError(tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", loginCode.TableName(ctx)), time.Now().UTC(), loginCode.ID, nid).Exec())
})); err != nil {
return nil, err
}

if loginCode == nil {
return nil, errors.WithStack(code.ErrCodeNotFound)
}

if loginCode.IsExpired() {
return nil, errors.WithStack(flow.NewFlowExpiredError(loginCode.ExpiresAt))
}

if loginCode.WasUsed() {
return nil, errors.WithStack(code.ErrCodeAlreadyUsed)
}

return loginCode, nil
}

func (p *Persister) DeleteLoginCodesOfFlow(ctx context.Context, flowID uuid.UUID) error {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginCodesOfFlow")
defer span.End()

//#nosec G201 -- TableName is static
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE selfservice_login_flow_id = ? AND nid = ?", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).Exec()
}

func (p *Persister) GetUsedLoginCode(ctx context.Context, flowID uuid.UUID) (*code.LoginCode, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetUsedLoginCode")
defer span.End()

var loginCode code.LoginCode
if err := p.Connection(ctx).RawQuery(fmt.Sprintf("SELECT * FROM %s WHERE selfservice_login_flow_id = ? AND nid = ? AND used_at IS NOT NULL", new(code.LoginCode).TableName(ctx)), flowID, p.NetworkID(ctx)).First(&loginCode); err != nil {
return nil, sqlcon.HandleError(err)
}
return &loginCode, nil
}
Loading

0 comments on commit d8a5c2f

Please sign in to comment.