Skip to content

Commit

Permalink
Merge pull request #10 from mkmba-nz/fix-9
Browse files Browse the repository at this point in the history
avoid re-using stale state in parallel auth flows
  • Loading branch information
btoews authored Oct 7, 2024
2 parents 53c253e + b38acba commit 822a18a
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 56 deletions.
13 changes: 2 additions & 11 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,10 @@ import (
type contextKey string

const (
contextKeyTransaction contextKey = "transaction"
contextKeyProvider contextKey = "provider"
contextKeyLog contextKey = "log"
contextKeyProvider contextKey = "provider"
contextKeyLog contextKey = "log"
)

func withTransaction(r *http.Request, t *Transaction) *http.Request {
return r.WithContext(context.WithValue(r.Context(), contextKeyTransaction, t))
}

func GetTransaction(r *http.Request) *Transaction {
return r.Context().Value(contextKeyTransaction).(*Transaction)
}

func withProvider(r *http.Request, p *provider) *http.Request {
return r.WithContext(context.WithValue(r.Context(), contextKeyProvider, p))
}
Expand Down
10 changes: 8 additions & 2 deletions oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ func (p *provider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (p *provider) handleStart(w http.ResponseWriter, r *http.Request) {
defer getLog(r).WithField("status", http.StatusFound).Info()

tr := ssokenizer.GetTransaction(r)
tr := ssokenizer.StartTransaction(w, r)
if tr == nil {
return
}

opts := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}

Expand All @@ -89,7 +92,10 @@ func (p *provider) handleStart(w http.ResponseWriter, r *http.Request) {
}

func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) {
tr := ssokenizer.GetTransaction(r)
tr := ssokenizer.RestoreTransaction(w, r)
if tr == nil {
return
}
params := r.URL.Query()

if errParam := params.Get("error"); errParam != "" {
Expand Down
60 changes: 49 additions & 11 deletions oauth2/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ func init() {
logrus.SetLevel(logrus.DebugLevel)
}

func TestOauth2(t *testing.T) {
const rpAuth = "555"

func setupTestServers(t *testing.T) (*httptest.Server, *ssokenizer.Server, *httptest.Server, *httptest.Server) {
rpServer := httptest.NewServer(rp)
defer rpServer.Close()
t.Cleanup(rpServer.Close)
t.Logf("rp=%s", rpServer.URL)

idpServer := httptest.NewServer(idp)
Expand All @@ -41,8 +43,6 @@ func TestOauth2(t *testing.T) {
openKey = hex.EncodeToString(priv[:])
)

const rpAuth = "555"

tkz := tokenizer.NewTokenizer(openKey)
tkz.Tr = http.DefaultTransport.(*http.Transport) // disable TLS requirement for app server
tkzServer := httptest.NewServer(tkz)
Expand All @@ -69,15 +69,14 @@ func TestOauth2(t *testing.T) {
Scopes: []string{"my scope"},
},
}, rpServer.URL, tokenizer.NewBearerAuthConfig(rpAuth)))
return rpServer, skz, tkzServer, idpServer
}

client := new(http.Client)
client.Jar, _ = cookiejar.New(nil)
client.Jar = noSecureJar{client.Jar}

resp, err := client.Get("http://" + skz.Address + "/idp/start")
assert.NoError(t, err)
func checkResponse(t *testing.T, resp *http.Response, expectedPrefix, expectedState string) string {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, strings.HasPrefix(resp.Request.URL.String(), rpServer.URL))
assert.True(t, strings.HasPrefix(resp.Request.URL.String(), expectedPrefix))
state := resp.Request.URL.Query().Get("state")
assert.Equal(t, expectedState, state)
errMsg := resp.Request.URL.Query().Get("error")
assert.Equal(t, "", errMsg)
sealed := resp.Request.URL.Query().Get("sealed")
Expand All @@ -87,6 +86,19 @@ func TestOauth2(t *testing.T) {
assert.NoError(t, err)
expires := time.Unix(iexpires, 0)
assert.Equal(t, 3599, time.Until(expires)/time.Second)
return sealed
}

func TestOauth2(t *testing.T) {
rpServer, skz, tkzServer, idpServer := setupTestServers(t)

client := new(http.Client)
client.Jar, _ = cookiejar.New(nil)
client.Jar = noSecureJar{client.Jar}

resp, err := client.Get("http://" + skz.Address + "/idp/start")
assert.NoError(t, err)
sealed := checkResponse(t, resp, rpServer.URL, "")

tkzClient, err := tokenizer.Client(tkzServer.URL, tokenizer.WithAuth(rpAuth), tokenizer.WithSecret(sealed, nil))
assert.NoError(t, err)
Expand All @@ -113,6 +125,32 @@ func TestOauth2(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

// tests that when two parallel flows are initiated, they do not interfere and the second can
// complete successfully.
func TestOauth2Parallel(t *testing.T) {
rpServer, skz, _, idpServer := setupTestServers(t)

sharedJar, _ := cookiejar.New(nil)

clientA := new(http.Client)
clientA.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if strings.HasPrefix(req.URL.String(), idpServer.URL) {
return nil // follow redirect to idp
}
return http.ErrUseLastResponse // don't follow redirect back from idp, simulating abandoned flow.
}
clientA.Jar = noSecureJar{sharedJar}
_, err := clientA.Get("http://" + skz.Address + "/idp/start?state=first")
assert.NoError(t, err)

clientB := new(http.Client)
clientB.Jar = noSecureJar{sharedJar}

resp, err := clientB.Get("http://" + skz.Address + "/idp/start?state=second")
assert.NoError(t, err)
checkResponse(t, resp, rpServer.URL, "second")
}

const (
testClientID = "my-client-id"
testClientSecret = "my-client-secret"
Expand Down
32 changes: 0 additions & 32 deletions ssokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package ssokenizer

import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/sirupsen/logrus"
"github.com/superfly/tokenizer"
Expand Down Expand Up @@ -56,36 +54,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
r = withProvider(r, provider)

t := &Transaction{
ReturnState: r.URL.Query().Get("state"),
Nonce: randHex(16),
Expiry: time.Now().Add(transactionTTL),
}

if tc, err := r.Cookie(transactionCookieName); err != http.ErrNoCookie && tc.Value != "" {
if err := unmarshalTransaction(t, tc.Value); err != nil {
r = WithError(r, fmt.Errorf("bad transaction cookie: %w", err))
t.ReturnError(w, r, "bad request")
return
}

if time.Now().After(t.Expiry) {
r = WithError(r, errors.New("expired transaction"))
t.ReturnError(w, r, "expired")
return
}
}

ts, err := t.marshal()
if err != nil {
r = WithError(r, fmt.Errorf("marshal transaction cookie: %w", err))
t.ReturnError(w, r, "unexpected error")
return
}

t.setCookie(w, r, ts)
r = withTransaction(r, t)
r.URL.Path = "/" + rest

provider.handler.ServeHTTP(w, r)
Expand Down
42 changes: 42 additions & 0 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -63,6 +65,46 @@ func (t *Transaction) returnData(w http.ResponseWriter, r *http.Request, data ma
http.Redirect(w, r, returnURL.String(), http.StatusFound)
}

func StartTransaction(w http.ResponseWriter, r *http.Request) *Transaction {
t := &Transaction{
ReturnState: r.URL.Query().Get("state"),
Nonce: randHex(16),
Expiry: time.Now().Add(transactionTTL),
}

ts, err := t.marshal()
if err != nil {
r = WithError(r, fmt.Errorf("marshal transaction cookie: %w", err))
t.ReturnError(w, r, "unexpected error")
return nil
}

t.setCookie(w, r, ts)
return t
}

func RestoreTransaction(w http.ResponseWriter, r *http.Request) *Transaction {
t := &Transaction{
Nonce: randHex(16),
Expiry: time.Now().Add(transactionTTL),
}

if tc, err := r.Cookie(transactionCookieName); err != http.ErrNoCookie && tc.Value != "" {
if err := unmarshalTransaction(t, tc.Value); err != nil {
r = WithError(r, fmt.Errorf("bad transaction cookie: %w", err))
t.ReturnError(w, r, "bad request")
return nil
}

if time.Now().After(t.Expiry) {
r = WithError(r, errors.New("expired transaction"))
t.ReturnError(w, r, "expired")
return nil
}
}
return t
}

func unmarshalTransaction(t *Transaction, s string) error {
m, err := base64.StdEncoding.DecodeString(s)
if err != nil {
Expand Down

0 comments on commit 822a18a

Please sign in to comment.