-
Notifications
You must be signed in to change notification settings - Fork 9
/
transaction.go
150 lines (122 loc) · 3.58 KB
/
transaction.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package ssokenizer
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/http"
"time"
"github.com/vmihailenco/msgpack/v5"
)
const (
	// transactionCookieName is the name of the cookie that carries the
	// encoded Transaction.
	transactionCookieName = "transaction"

	// transactionTTL is added to the current time to compute a new
	// Transaction's Expiry.
	transactionTTL = 60 * time.Minute
)
// Transaction holds the state of a user's SSO attempt. It is stored in a
// cookie as base64-encoded msgpack. The cookie path is scoped per provider so
// that transactions for different providers do not interfere with each other.
//
// NOTE(review): nothing visible in this file signs or encrypts the cookie
// value, so a client can presumably read and rewrite every field (including
// Expiry). Confirm that integrity is enforced elsewhere or is not required.
type Transaction struct {
	// State supplied by the relying party (the "state" query parameter of the
	// request that started the transaction). It is echoed back as "state" in
	// our redirect to the relying party so that it can detect login-CSRF
	// attacks.
	ReturnState string

	// Random string (16 random bytes, hex-encoded) that provider
	// implementations can use as the state parameter for downstream SSO flows.
	Nonce string

	// Time after which RestoreTransaction rejects this transaction as expired.
	Expiry time.Time
}
// ReturnData ends the transaction. It clears the transaction cookie and
// redirects the user to the provider's return URL. Each key/value pair in data
// is added as a query string parameter, and the relying party's state, when
// there is one, is added as "state".
func (t *Transaction) ReturnData(w http.ResponseWriter, r *http.Request, data map[string]string) {
	t.returnData(w, r, data)
}
// ReturnError ends the transaction by redirecting the user to the provider's
// return URL with msg in the "error" query string parameter. It is ReturnData
// with a single-entry map.
func (t *Transaction) ReturnError(w http.ResponseWriter, r *http.Request, msg string) {
	params := map[string]string{"error": msg}
	t.returnData(w, r, params)
}
// returnData clears the transaction cookie and sends a 302 Found redirect to
// the provider's return URL. Query parameters already present on the
// configured URL are kept. Every entry of data is then set on the query, and a
// non-empty ReturnState is set as "state", overriding any "state" key in data.
func (t *Transaction) returnData(w http.ResponseWriter, r *http.Request, data map[string]string) {
	defer GetLog(r).WithField("status", http.StatusFound).Info()

	t.setCookie(w, r, "")

	// Build the redirect on a shallow copy of the provider's return URL. The
	// provider's URL is presumably shared by every request for that provider.
	// Assigning RawQuery on the shared value would race with concurrent
	// requests. It would also leak one user's parameters (error, data, state)
	// into the next user's redirect, because Query() starts from whatever
	// RawQuery currently holds.
	// NOTE(review): the dereference assumes the field is a *url.URL; its
	// declaration is not visible in this file.
	returnURL := *getProvider(r).returnURL

	q := returnURL.Query()
	for k, v := range data {
		q.Set(k, v)
	}
	if t.ReturnState != "" {
		q.Set("state", t.ReturnState)
	}
	returnURL.RawQuery = q.Encode()

	http.Redirect(w, r, returnURL.String(), http.StatusFound)
}
// StartTransaction begins a new SSO attempt and does the following:
//   - records the relying party's "state" query parameter;
//   - generates a fresh random Nonce;
//   - sets Expiry to transactionTTL from now;
//   - stores the encoded Transaction in the provider-scoped transaction cookie.
//
// If the Transaction cannot be encoded, the user is redirected back to the
// return URL with error "unexpected error" and nil is returned.
func StartTransaction(w http.ResponseWriter, r *http.Request) *Transaction {
	tx := new(Transaction)
	tx.ReturnState = r.URL.Query().Get("state")
	tx.Nonce = randHex(16)
	tx.Expiry = time.Now().Add(transactionTTL)

	encoded, err := tx.marshal()
	if err == nil {
		tx.setCookie(w, r, encoded)
		return tx
	}

	r = WithError(r, fmt.Errorf("marshal transaction cookie: %w", err))
	tx.ReturnError(w, r, "unexpected error")
	return nil
}
// RestoreTransaction recovers the Transaction that StartTransaction stored in
// the transaction cookie.
//
// If the request has no transaction cookie, or the cookie's value is empty, it
// returns a fresh Transaction. That Transaction has a new random Nonce, an empty
// ReturnState and an Expiry of transactionTTL from now.
//
// If the cookie cannot be decoded, or the decoded transaction has expired, it
// writes a redirect carrying an "error" parameter ("bad request" or
// "expired") to w and returns nil.
func RestoreTransaction(w http.ResponseWriter, r *http.Request) *Transaction {
	// Defaults, used as-is when there is no cookie to restore from.
	t := &Transaction{
		Nonce:  randHex(16),
		Expiry: time.Now().Add(transactionTTL),
	}

	// Test err == nil rather than comparing against http.ErrNoCookie. tc is
	// only guaranteed non-nil on success, so this can never dereference a nil
	// cookie, whatever error r.Cookie reports.
	tc, err := r.Cookie(transactionCookieName)
	if err != nil || tc.Value == "" {
		return t
	}

	if err = unmarshalTransaction(t, tc.Value); err != nil {
		r = WithError(r, fmt.Errorf("bad transaction cookie: %w", err))
		t.ReturnError(w, r, "bad request")
		return nil
	}
	if time.Now().After(t.Expiry) {
		r = WithError(r, errors.New("expired transaction"))
		t.ReturnError(w, r, "expired")
		return nil
	}
	return t
}
// unmarshalTransaction decodes s into t. It reverses marshal: s is
// standard-alphabet base64 wrapping a msgpack-encoded Transaction. It returns
// the first error from either decoding step.
func unmarshalTransaction(t *Transaction, s string) error {
	raw, err := base64.StdEncoding.DecodeString(s)
	if err == nil {
		err = msgpack.Unmarshal(raw, t)
	}
	return err
}
// marshal encodes t for storage in a cookie: msgpack, then standard-alphabet
// base64. On failure it returns an empty string and the msgpack error.
func (t *Transaction) marshal() (string, error) {
	var encoded string
	packed, err := msgpack.Marshal(t)
	if err == nil {
		encoded = base64.StdEncoding.EncodeToString(packed)
	}
	return encoded, err
}
// setCookie writes the transaction cookie with value v. The cookie is scoped to
// the path "/" + <provider name>, and is HttpOnly with SameSite=Lax.
//
// An empty v deletes the cookie: a negative MaxAge is sent as Max-Age=0. A
// non-empty v gets no Max-Age, so the browser treats it as a session cookie.
// The cookie is marked Secure when the request arrived over TLS or carries
// "X-Forwarded-Proto: https".
// NOTE(review): that header is client-supplied unless a proxy overwrites it.
func (t *Transaction) setCookie(w http.ResponseWriter, r *http.Request, v string) {
	cookie := &http.Cookie{
		Name:     transactionCookieName,
		Value:    v,
		Path:     "/" + getProvider(r).name,
		Secure:   r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	if v == "" {
		cookie.MaxAge = -1
	}
	http.SetCookie(w, cookie)
}
// randHex returns n bytes from crypto/rand, hex-encoded (2*n characters).
// Callers use it for nonces, so the value must be unpredictable.
//
// It panics if the system random source fails. A failed read would leave b
// zeroed, or partly zeroed, and yield a guessable nonce. (Since Go 1.24,
// crypto/rand.Read never returns an error and crashes the program itself.)
func randHex(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("randHex: reading crypto/rand: %v", err))
	}
	return hex.EncodeToString(b)
}