Skip to content

Commit

Permalink
chore: synchronize workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 25, 2023
1 parent fb4d01d commit ca58dee
Show file tree
Hide file tree
Showing 14 changed files with 280 additions and 33 deletions.
23 changes: 10 additions & 13 deletions persistence/sql/persister_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ type oneTimeCodeProvider interface {
GetHMACCode() string
}

func useOneTimeCode[T oneTimeCodeProvider](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, target T, flowTableName string, foreignKeyName string) error {
func useOneTimeCode[P any, U interface {
*P
oneTimeCodeProvider
}](ctx context.Context, p *Persister, flowID uuid.UUID, userProvidedCode string, flowTableName string, foreignKeyName string) (U, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.useOneTimeCode")
defer span.End()

var found bool

var target U
nid := p.NetworkID(ctx)
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
//#nosec G201 -- TableName is static
Expand All @@ -51,7 +53,7 @@ func useOneTimeCode[T oneTimeCodeProvider](ctx context.Context, p *Persister, fl
return errors.WithStack(code.ErrCodeSubmittedTooOften)
}

var codes []T
var codes []U
if err := sqlcon.HandleError(tx.Where(fmt.Sprintf("nid = ? AND %s = ?", foreignKeyName), nid, flowID).All(&codes)); err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
// Return no error, as that would roll back the transaction and reset the submit count.
Expand All @@ -71,29 +73,24 @@ func useOneTimeCode[T oneTimeCodeProvider](ctx context.Context, p *Persister, fl
continue
}
target = c
found = true
break secrets
}
}

if !found || target.Validate() != nil {
if target.Validate() != nil {
// Return no error, as that would roll back the transaction
return nil
}

//#nosec G201 -- TableName is static
return tx.RawQuery(fmt.Sprintf("UPDATE %s SET used_at = ? WHERE id = ? AND nid = ?", target.TableName(ctx)), time.Now().UTC(), target.GetID(), nid).Exec()
}); err != nil {
return sqlcon.HandleError(err)
}

if !found {
return errors.WithStack(code.ErrCodeNotFound)
return nil, sqlcon.HandleError(err)
}

if err := target.Validate(); err != nil {
return err
return nil, err
}

return nil
return target, nil
}
4 changes: 2 additions & 2 deletions persistence/sql/persister_login_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func (p *Persister) UseLoginCode(ctx context.Context, flowID uuid.UUID, identity
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseLoginCode")
defer span.End()

var codeRow *code.LoginCode // This has to be nil per default
if err := useOneTimeCode[*code.LoginCode](ctx, p, flowID, userProvidedCode, codeRow, new(login.Flow).TableName(ctx), "selfservice_login_flow_id"); err != nil {
codeRow, err := useOneTimeCode[code.LoginCode, *code.LoginCode](ctx, p, flowID, userProvidedCode, new(login.Flow).TableName(ctx), "selfservice_login_flow_id")

Check failure on line 40 in persistence/sql/persister_login_code.go

View workflow job for this annotation

GitHub Actions / Run tests and lints

SA4006: this value of `codeRow` is never used (staticcheck)
if err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_recovery_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (p *Persister) UseRecoveryCode(ctx context.Context, flowID uuid.UUID, userP
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryCode")
defer span.End()

var codeRow *code.RecoveryCode // This has to be nil per default
if err := useOneTimeCode[*code.RecoveryCode](ctx, p, flowID, userProvidedCode, codeRow, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id"); err != nil {
codeRow, err := useOneTimeCode[code.RecoveryCode, *code.RecoveryCode](ctx, p, flowID, userProvidedCode, new(recovery.Flow).TableName(ctx), "selfservice_recovery_flow_id")
if err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_registration_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func (p *Persister) UseRegistrationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRegistrationCode")
defer span.End()

var codeRow *code.RegistrationCode // This has to be nil per default
if err := useOneTimeCode[*code.RegistrationCode](ctx, p, flowID, userProvidedCode, codeRow, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id"); err != nil {
codeRow, err := useOneTimeCode[code.RegistrationCode, *code.RegistrationCode](ctx, p, flowID, userProvidedCode, new(registration.Flow).TableName(ctx), "selfservice_registration_flow_id")
if err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions persistence/sql/persister_verification_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (p *Persister) UseVerificationCode(ctx context.Context, flowID uuid.UUID, u
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseVerificationCode")
defer span.End()

var codeRow *code.VerificationCode // This has to be nil per default
if err := useOneTimeCode[*code.VerificationCode](ctx, p, flowID, userProvidedCode, codeRow, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id"); err != nil {
codeRow, err := useOneTimeCode[code.VerificationCode, *code.VerificationCode](ctx, p, flowID, userProvidedCode, new(verification.Flow).TableName(ctx), "selfservice_verification_flow_id")
if err != nil {
return nil, err
}

Expand Down
3 changes: 3 additions & 0 deletions selfservice/strategy/code/code_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func (LoginCode) TableName(ctx context.Context) string {
}

func (f *LoginCode) Validate() error {
if f == nil {
return errors.WithStack(ErrCodeNotFound)
}
if f.ExpiresAt.Before(time.Now().UTC()) {
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
}
Expand Down
76 changes: 76 additions & 0 deletions selfservice/strategy/code/code_login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package code_test

import (
"database/sql"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/x"
"github.com/ory/x/urlx"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"time"
)

func TestLoginCode(t *testing.T) {
conf, _ := internal.NewFastRegistryWithMocks(t)

newCode := func(expiresIn time.Duration, f *login.Flow) *code.LoginCode {
return &code.LoginCode{
ID: x.NewUUID(),
FlowID: f.ID,
ExpiresAt: time.Now().Add(expiresIn),
}
}

req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")}
t.Run("method=Validate", func(t *testing.T) {
t.Parallel()

t.Run("case=returns error if flow is expired", func(t *testing.T) {
f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(-time.Hour, f)
expected := new(flow.ExpiredError)
require.ErrorAs(t, c.Validate(), &expected)
})
t.Run("case=returns no error if flow is not expired", func(t *testing.T) {
f, err := login.NewFlow(conf, time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
require.NoError(t, c.Validate())
})

t.Run("case=returns error if flow has been used", func(t *testing.T) {
f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
c.UsedAt = sql.NullTime{
Time: time.Now(),
Valid: true,
}
require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed)
})

t.Run("case=returns no error if flow has not been used", func(t *testing.T) {
f, err := login.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
c.UsedAt = sql.NullTime{
Valid: false,
}
require.NoError(t, c.Validate())
})

t.Run("case=returns error if flow is nil", func(t *testing.T) {
var c *code.LoginCode
require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound)
})
})
}
3 changes: 3 additions & 0 deletions selfservice/strategy/code/code_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ func (RecoveryCode) TableName(ctx context.Context) string {
}

func (f *RecoveryCode) Validate() error {
if f == nil {
return errors.WithStack(ErrCodeNotFound)
}
if f.ExpiresAt.Before(time.Now().UTC()) {
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
}
Expand Down
28 changes: 17 additions & 11 deletions selfservice/strategy/code/code_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,26 @@ func TestRecoveryCode(t *testing.T) {
}

req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")}
t.Run("method=Validate", func(t *testing.T) {
t.Parallel()

t.Run("method=IsExpired", func(t *testing.T) {
t.Run("case=returns true if flow is expired", func(t *testing.T) {
t.Run("case=returns error if flow is expired", func(t *testing.T) {
f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(-time.Hour, f)
require.True(t, c.IsExpired())
expected := new(flow.ExpiredError)
require.ErrorAs(t, c.Validate(), &expected)
})
t.Run("case=returns false if flow is not expired", func(t *testing.T) {
t.Run("case=returns no error if flow is not expired", func(t *testing.T) {
f, err := recovery.NewFlow(conf, time.Hour, "", req, nil, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
require.False(t, c.IsExpired())
require.NoError(t, c.Validate())
})
})

t.Run("method=WasUsed", func(t *testing.T) {
t.Run("case=returns true if flow has been used", func(t *testing.T) {
t.Run("case=returns error if flow has been used", func(t *testing.T) {
f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser)
require.NoError(t, err)

Expand All @@ -62,17 +62,23 @@ func TestRecoveryCode(t *testing.T) {
Time: time.Now(),
Valid: true,
}
require.True(t, c.WasUsed())
require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed)
})
t.Run("case=returns false if flow has not been used", func(t *testing.T) {

t.Run("case=returns no error if flow has not been used", func(t *testing.T) {
f, err := recovery.NewFlow(conf, -time.Hour, "", req, nil, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
c.UsedAt = sql.NullTime{
Valid: false,
}
require.False(t, c.WasUsed())
require.NoError(t, c.Validate())
})

t.Run("case=returns error if flow is nil", func(t *testing.T) {
var c *code.RecoveryCode
require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound)
})
})
}
Expand Down
3 changes: 3 additions & 0 deletions selfservice/strategy/code/code_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func (RegistrationCode) TableName(ctx context.Context) string {
}

func (f *RegistrationCode) Validate() error {
if f == nil {
return errors.WithStack(ErrCodeNotFound)
}
if f.ExpiresAt.Before(time.Now().UTC()) {
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
}
Expand Down
75 changes: 75 additions & 0 deletions selfservice/strategy/code/code_registration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package code_test

import (
"database/sql"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/code"
"github.com/ory/kratos/x"
"github.com/ory/x/urlx"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"time"
)

func TestRegistrationCode(t *testing.T) {
conf, _ := internal.NewFastRegistryWithMocks(t)
newCode := func(expiresIn time.Duration, f *registration.Flow) *code.RegistrationCode {
return &code.RegistrationCode{
ID: x.NewUUID(),
FlowID: f.ID,
ExpiresAt: time.Now().Add(expiresIn),
}
}

req := &http.Request{URL: urlx.ParseOrPanic("https://www.ory.sh/")}
t.Run("method=Validate", func(t *testing.T) {
t.Parallel()

t.Run("case=returns error if flow is expired", func(t *testing.T) {
f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(-time.Hour, f)
expected := new(flow.ExpiredError)
require.ErrorAs(t, c.Validate(), &expected)
})
t.Run("case=returns no error if flow is not expired", func(t *testing.T) {
f, err := registration.NewFlow(conf, time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
require.NoError(t, c.Validate())
})

t.Run("case=returns error if flow has been used", func(t *testing.T) {
f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
c.UsedAt = sql.NullTime{
Time: time.Now(),
Valid: true,
}
require.ErrorIs(t, c.Validate(), code.ErrCodeAlreadyUsed)
})

t.Run("case=returns no error if flow has not been used", func(t *testing.T) {
f, err := registration.NewFlow(conf, -time.Hour, "", req, flow.TypeBrowser)
require.NoError(t, err)

c := newCode(time.Hour, f)
c.UsedAt = sql.NullTime{
Valid: false,
}
require.NoError(t, c.Validate())
})

t.Run("case=returns error if flow is nil", func(t *testing.T) {
var c *code.RegistrationCode
require.ErrorIs(t, c.Validate(), code.ErrCodeNotFound)
})
})
}
7 changes: 6 additions & 1 deletion selfservice/strategy/code/code_sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,14 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit
WithSensitiveField("address", addresses).
Debugf("Preparing %s code", f.GetFlowName())

// We generate the code once and use it for all addresses. This is important because
// generating different codes per address would reduce the search space for an attacker.
//
// See also [this discussion](https://github.com/ory/kratos/pull/3378#discussion_r1305436968).
rawCode := GenerateCode()

// send to all addresses
for _, address := range addresses {
rawCode := GenerateCode()

switch f.GetFlowName() {
case flow.RegistrationFlow:
Expand Down
3 changes: 3 additions & 0 deletions selfservice/strategy/code/code_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ func (VerificationCode) TableName(context.Context) string {
// - If the code was already used `ErrCodeAlreadyUsed` is returnd
// - Otherwise, `nil` is returned
func (f *VerificationCode) Validate() error {
if f == nil {
return errors.WithStack(ErrCodeNotFound)
}
if f.ExpiresAt.Before(time.Now().UTC()) {
return errors.WithStack(flow.NewFlowExpiredError(f.ExpiresAt))
}
Expand Down
Loading

0 comments on commit ca58dee

Please sign in to comment.