Skip to content

Commit

Permalink
fix: don't require session for OIDC verification
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Aug 21, 2023
1 parent 2f5ba1f commit 3cd81a9
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 54 deletions.
11 changes: 11 additions & 0 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"strings"
"sync"
"testing"
"time"

"github.com/ory/x/contextx"
Expand Down Expand Up @@ -390,6 +391,16 @@ func (m *RegistryDefault) WithConfig(c *config.Config) Registry {
return m
}

// WithSelfserviceStrategies is only available in testing and overrides the
// selfservice strategies with the given ones.
func (m *RegistryDefault) WithSelfserviceStrategies(t testing.TB, strategies []any) Registry {
if t == nil {
panic("Passing selfservice strategies is only supported in testing")
}
m.selfserviceStrategies = strategies
return m
}

func (m *RegistryDefault) Writer() herodot.Writer {
if m.writer == nil {
h := herodot.NewJSONWriter(m.Logger())
Expand Down
4 changes: 2 additions & 2 deletions driver/registry_default_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ func (m *RegistryDefault) VerificationStrategies(ctx context.Context) (verificat
return
}

func (m *RegistryDefault) AllVerificationStrategies() (recoveryStrategies verification.Strategies) {
func (m *RegistryDefault) AllVerificationStrategies() (verificationStrategies verification.Strategies) {
for _, strategy := range m.selfServiceStrategies() {
if s, ok := strategy.(verification.Strategy); ok {
recoveryStrategies = append(recoveryStrategies, s)
verificationStrategies = append(verificationStrategies, s)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE selfservice_verification_flows DROP COLUMN session_id;

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE selfservice_verification_flows ADD COLUMN session_id VARCHAR(36);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE selfservice_verification_flows ADD COLUMN session_id UUID;
30 changes: 15 additions & 15 deletions schema/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestHandler(t *testing.T) {
return s
}

getFromTS := func(url string, expectCode int) []byte {
getFromTS := func(t *testing.T, url string, expectCode int) []byte {
res, err := ts.Client().Get(url)
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
Expand All @@ -92,12 +92,12 @@ func TestHandler(t *testing.T) {
return body
}

getFromTSById := func(id string, expectCode int) []byte {
return getFromTS(fmt.Sprintf("%s/schemas/%s", ts.URL, id), expectCode)
getFromTSById := func(t *testing.T, id string, expectCode int) []byte {
return getFromTS(t, fmt.Sprintf("%s/schemas/%s", ts.URL, id), expectCode)
}

getFromTSPaginated := func(page, perPage, expectCode int) []byte {
return getFromTS(fmt.Sprintf("%s/schemas?page=%d&per_page=%d", ts.URL, page, perPage), expectCode)
getFromTSPaginated := func(t *testing.T, page, perPage, expectCode int) []byte {
return getFromTS(t, fmt.Sprintf("%s/schemas?page=%d&per_page=%d", ts.URL, page, perPage), expectCode)
}

getFromFS := func(id string) []byte {
Expand Down Expand Up @@ -132,46 +132,46 @@ func TestHandler(t *testing.T) {
setSchemas(schemas)

t.Run("case=get default schema", func(t *testing.T) {
server := getFromTSById(config.DefaultIdentityTraitsSchemaID, http.StatusOK)
server := getFromTSById(t, config.DefaultIdentityTraitsSchemaID, http.StatusOK)
file := getFromFS(config.DefaultIdentityTraitsSchemaID)
require.JSONEq(t, string(file), string(server))
})

t.Run("case=get other schema", func(t *testing.T) {
server := getFromTSById("identity2", http.StatusOK)
server := getFromTSById(t, "identity2", http.StatusOK)
file := getFromFS("identity2")
require.JSONEq(t, string(file), string(server))
})

t.Run("case=get base64 schema", func(t *testing.T) {
server := getFromTSById("base64", http.StatusOK)
server := getFromTSById(t, "base64", http.StatusOK)
file := getFromFS("base64")
require.JSONEq(t, string(file), string(server))
})

t.Run("case=get encoded schema", func(t *testing.T) {
server := getFromTSById("cHJlc2V0Oi8vZW1haWw", http.StatusOK)
server := getFromTSById(t, "cHJlc2V0Oi8vZW1haWw", http.StatusOK)
file := getFromFS("preset://email")
require.JSONEq(t, string(file), string(server))
})

t.Run("case=get unreachable schema", func(t *testing.T) {
reason := getFromTSById("unreachable", http.StatusInternalServerError)
reason := getFromTSById(t, "unreachable", http.StatusInternalServerError)
require.Contains(t, string(reason), "could not be found or opened")
})

t.Run("case=get no-file schema", func(t *testing.T) {
reason := getFromTSById("no-file", http.StatusInternalServerError)
reason := getFromTSById(t, "no-file", http.StatusInternalServerError)
require.Contains(t, string(reason), "could not be found or opened")
})

t.Run("case=get directory schema", func(t *testing.T) {
reason := getFromTSById("directory", http.StatusInternalServerError)
reason := getFromTSById(t, "directory", http.StatusInternalServerError)
require.Contains(t, string(reason), "could not be found or opened")
})

t.Run("case=get not-existing schema", func(t *testing.T) {
_ = getFromTSById("not-existing", http.StatusNotFound)
_ = getFromTSById(t, "not-existing", http.StatusNotFound)
})

t.Run("case=get all schemas", func(t *testing.T) {
Expand All @@ -188,7 +188,7 @@ func TestHandler(t *testing.T) {
},
})

body := getFromTSPaginated(0, 2, http.StatusOK)
body := getFromTSPaginated(t, 0, 2, http.StatusOK)

var result []client.IdentitySchemaContainer
require.NoError(t, json.Unmarshal(body, &result))
Expand Down Expand Up @@ -230,7 +230,7 @@ func TestHandler(t *testing.T) {
},
})

body1, body2 := getFromTSPaginated(0, 1, http.StatusOK), getFromTSPaginated(1, 1, http.StatusOK)
body1, body2 := getFromTSPaginated(t, 0, 1, http.StatusOK), getFromTSPaginated(t, 1, 1, http.StatusOK)

var result1, result2 schema.IdentitySchemas
require.NoError(t, json.Unmarshal(body1, &result1))
Expand Down
8 changes: 1 addition & 7 deletions selfservice/flow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ import (

"github.com/pkg/errors"

"github.com/ory/herodot"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/ui/container"
"github.com/ory/x/sqlxx"

"github.com/ory/herodot"
"github.com/ory/kratos/x"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -42,10 +40,6 @@ type Flow interface {
GetUI() *container.Container
}

type Challenger interface {
GetOAuth2LoginChallenge() sqlxx.NullString
}

type FlowWithRedirect interface {
SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []x.SecureRedirectOption)
}
6 changes: 1 addition & 5 deletions selfservice/flow/registration/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func NewFlow(conf *config.Config, exp time.Duration, csrf string, r *http.Reques
}, nil
}

func (f Flow) TableName(ctx context.Context) string {
func (f Flow) TableName(context.Context) string {
return "selfservice_registration_flows"
}

Expand Down Expand Up @@ -229,10 +229,6 @@ func (f *Flow) ContinueWith() []flow.ContinueWith {
return f.ContinueWithItems
}

func (f Flow) GetOAuth2LoginChallenge() sqlxx.NullString {
return f.OAuth2LoginChallenge
}

func (f *Flow) SecureRedirectToOpts(ctx context.Context, cfg config.Provider) (opts []x.SecureRedirectOption) {
return []x.SecureRedirectOption{
x.SecureRedirectReturnTo(f.ReturnTo),
Expand Down
34 changes: 34 additions & 0 deletions selfservice/flow/verification/fake_strategy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package verification

import (
"context"
"net/http"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/ui/node"
)

type FakeStrategy struct{}

func (f FakeStrategy) VerificationStrategyID() string {
return "fake"
}

func (f FakeStrategy) VerificationNodeGroup() node.UiNodeGroup {
return "fake"
}

func (f FakeStrategy) PopulateVerificationMethod(*http.Request, *Flow) error {
return nil
}

func (f FakeStrategy) Verify(_ http.ResponseWriter, _ *http.Request, _ *Flow) (err error) {
return nil
}

func (f FakeStrategy) SendVerificationEmail(context.Context, *Flow, *identity.Identity, *identity.VerifiableAddress) error {
return nil
}
6 changes: 5 additions & 1 deletion selfservice/flow/verification/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ type Flow struct {
// required: true
State State `json:"state" faker:"-" db:"state"`

// OAuth2LoginChallenge holds the login challenge originally set during the registration flow.
OAuth2LoginChallenge sqlxx.NullString `json:"-" db:"oauth2_login_challenge"`

// SessionID holds the session id if set from a registraton hook.
SessionID uuid.NullUUID `json:"-" faker:"-" db:"session_id"`

// CSRFToken contains the anti-csrf token associated with this request.
CSRFToken string `json:"-" db:"csrf_token"`

Expand All @@ -99,7 +103,7 @@ func (f *Flow) GetRequestURL() string {
return f.RequestURL
}

func (f Flow) TableName(ctx context.Context) string {
func (f Flow) TableName(context.Context) string {
return "selfservice_verification_flows"
}

Expand Down
39 changes: 28 additions & 11 deletions selfservice/flow/verification/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package verification

import (
"context"
"net/http"
"time"

Expand Down Expand Up @@ -47,6 +48,7 @@ type (
identity.PrivilegedPoolProvider
config.Provider
hydra.Provider
session.PersistenceProvider
session.ManagementProvider

x.CSRFTokenGeneratorProvider
Expand Down Expand Up @@ -396,7 +398,8 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request,
return
}

f, err := h.d.VerificationFlowPersister().GetVerificationFlow(r.Context(), rid)
ctx := r.Context()
f, err := h.d.VerificationFlowPersister().GetVerificationFlow(ctx, rid)
if errors.Is(err, sqlcon.ErrNoRows) {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, nil, node.DefaultGroup, errors.WithStack(herodot.ErrNotFound.WithReasonf("The verification request could not be found. Please restart the flow.")))
return
Expand Down Expand Up @@ -441,19 +444,20 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request,
if x.IsBrowserRequest(r) {
// Special case: If we ended up here through a OAuth2 login challenge, we need to accept the login request
// and redirect back to the OAuth2 provider.
if found && f.OAuth2LoginChallenge.String() != "" {
s, err := h.d.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, err)
if f.OAuth2LoginChallenge.String() != "" {
sess := h.maybeGetSession(ctx, f)
if sess == nil {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup,
herodot.ErrBadRequest.WithReasonf("No session was found for this flow. Please retry the authentication."))
return
}

callbackURL, err := h.d.Hydra().AcceptLoginRequest(r.Context(),
callbackURL, err := h.d.Hydra().AcceptLoginRequest(ctx,
hydra.AcceptLoginRequestParams{
LoginChallenge: string(f.OAuth2LoginChallenge),
IdentityID: s.IdentityID.String(),
SessionID: s.ID.String(),
AuthenticationMethods: s.AMR,
IdentityID: sess.IdentityID.String(),
SessionID: sess.ID.String(),
AuthenticationMethods: sess.AMR,
})
if err != nil {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, err)
Expand All @@ -464,15 +468,28 @@ func (h *Handler) updateVerificationFlow(w http.ResponseWriter, r *http.Request,
return
}

http.Redirect(w, r, f.AppendTo(h.d.Config().SelfServiceFlowVerificationUI(r.Context())).String(), http.StatusSeeOther)
http.Redirect(w, r, f.AppendTo(h.d.Config().SelfServiceFlowVerificationUI(ctx)).String(), http.StatusSeeOther)
return
}

updatedFlow, err := h.d.VerificationFlowPersister().GetVerificationFlow(r.Context(), f.ID)
updatedFlow, err := h.d.VerificationFlowPersister().GetVerificationFlow(ctx, f.ID)
if err != nil {
h.d.VerificationFlowErrorHandler().WriteFlowError(w, r, f, g, err)
return
}

h.d.Writer().Write(w, r, updatedFlow)
}

// maybeGetSession returns the session if it was found in the flow or nil otherwise.
func (h *Handler) maybeGetSession(ctx context.Context, f *Flow) *session.Session {
if !f.SessionID.Valid {
return nil
}
s, err := h.d.SessionPersister().GetSession(ctx, f.SessionID.UUID, session.ExpandNothing)
if err != nil {
return nil
}

return s
}
Loading

0 comments on commit 3cd81a9

Please sign in to comment.