-
Notifications
You must be signed in to change notification settings - Fork 139
/
Copy pathoidc_api.go
348 lines (293 loc) · 11.7 KB
/
oidc_api.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
// Copyright 2016-Present Couchbase, Inc.
//
// Use of this software is governed by the Business Source License included
// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified
// in that file, in accordance with the Business Source License, use of this
// software will be governed by the Apache License, Version 2.0, included in
// the file licenses/APL2.txt.
package rest
import (
"fmt"
"net/http"
"net/url"
"time"
"github.com/couchbase/sync_gateway/auth"
"github.com/couchbase/sync_gateway/base"
"golang.org/x/oauth2"
)
// Names of the query parameters used by the OIDC handlers in this file, followed
// by the name and lifetime of the CSRF state cookie. requestParamScope and
// requestParamRedirectURI are not referenced in the visible part of this file.
const (
requestParamCode = "code"
requestParamProvider = "provider"
requestParamOffline = "offline"
requestParamError = "error"
requestParamErrorDesc = "error_description"
requestParamRefreshToken = "refresh_token"
requestParamState = "state"
requestParamScope = "scope"
requestParamRedirectURI = "redirect_uri"
// stateCookieName is the name of state cookie to prevent cross-site request forgery (CSRF).
stateCookieName = "sg-oidc-state"
// stateCookieMaxAge represents the number of seconds until the cookie expires.
// The state cookie expires in 5 minutes from when the response was generated.
stateCookieMaxAge = 300
)
// Errors returned while validating the state cookie during the OIDC callback.
// All three are created with http.StatusBadRequest.
var (
// ErrNoStateCookie is returned by handler's handleOIDCCallback method when
// the state cookie is not found during OpenID Connect Auth callback.
ErrNoStateCookie = base.HTTPErrorf(http.StatusBadRequest, "OIDC Auth Failure: No state cookie found, client needs to support cookies when OIDC callback state is enabled")
// ErrStateMismatch is returned by handler's handleOIDCCallback method when
// the state cookie value doesn't match with state param in the callback URL
// during OpenID Connect Auth callback.
ErrStateMismatch = base.HTTPErrorf(http.StatusBadRequest, "OIDC Auth Failure: State mismatch")
// ErrReadStateCookie is returned by handler's handleOIDCCallback method when
// there is failure reading state cookie value during OpenID Connect Auth callback.
ErrReadStateCookie = base.HTTPErrorf(http.StatusBadRequest, "OIDC Auth Failure: Couldn't read state")
)
const (
// keyIDToken is the name of the extra field of the oauth2 token from which
// the raw ID token is read (see the token.Extra lookup in handleOIDCRefresh).
keyIDToken = "id_token"
)
// OIDCTokenResponse is the JSON body written by handleOIDCCallback and
// handleOIDCRefresh. Every field is tagged omitempty, so unset fields are left
// out of the encoded response. The access token fields (AccessToken, TokenType,
// Expires) are only populated by those handlers when the provider has
// IncludeAccessToken set; handleOIDCRefresh does not populate RefreshToken.
type OIDCTokenResponse struct {
IDToken string `json:"id_token,omitempty"` // ID token, from OP
RefreshToken string `json:"refresh_token,omitempty"` // Refresh token, from OP
SessionID string `json:"session_id,omitempty"` // Sync Gateway session ID
Username string `json:"name,omitempty"` // Sync Gateway user name
AccessToken string `json:"access_token,omitempty"` // Access token, from OP
TokenType string `json:"token_type,omitempty"` // Access token type, from OP
Expires int `json:"expires_in,omitempty"` // Access token expiry, from OP
}
// handleOIDC builds the provider's authorization URL via handleOIDCCommon and,
// when that succeeds, answers the request with a 302 (Found) redirect to it.
// Any error from handleOIDCCommon is returned unchanged and no redirect is sent.
func (h *handler) handleOIDC() error {
	target, err := h.handleOIDCCommon()
	if err == nil {
		http.Redirect(h.response, h.rq, target, http.StatusFound)
	}
	return err
}
// handleOIDCChallenge builds the provider's authorization URL via
// handleOIDCCommon and advertises it in a WWW-Authenticate header of the form
// `OIDC login="<url>"` instead of redirecting. On success it always returns
// ErrLoginRequired; an error from handleOIDCCommon is returned unchanged and
// no header is set.
func (h *handler) handleOIDCChallenge() error {
	loginURL, err := h.handleOIDCCommon()
	if err != nil {
		return err
	}
	h.setHeader("WWW-Authenticate", fmt.Sprintf("OIDC login=%q", loginURL))
	return ErrLoginRequired
}
// handleOIDCCommon is the shared first step of handleOIDC and
// handleOIDCChallenge. It resolves the provider named by the "provider" query
// parameter, obtains its client, and returns the authorization-code URL the
// caller should send the user to.
//
// Unless the provider has DisableCallbackState set, a random state value is
// generated, embedded in the URL, and stored in the state cookie on the
// response so that handleOIDCCallback can compare the two. When the "offline"
// query parameter is true, the URL requests offline access and forces the
// approval prompt.
//
// On any failure the returned URL is empty.
func (h *handler) handleOIDCCommon() (redirectURLString string, err error) {
	name := h.getQuery(requestParamProvider)
	base.InfofCtx(h.ctx(), base.KeyAuth, "Getting provider for name %v", base.UD(name))

	provider, err := h.getOIDCProvider(name)
	if provider == nil || err != nil {
		return "", err
	}

	client, err := provider.GetClient(h.ctx(), h.getOIDCCallbackURL)
	if err != nil {
		return "", base.HTTPErrorf(
			http.StatusInternalServerError, "Unable to obtain client for provider: %s - %v", name, err)
	}

	// Set state to prevent cross-site request forgery (CSRF) when DisableCallbackState is not enabled.
	state := ""
	if !provider.DisableCallbackState {
		if state, err = base.GenerateRandomSecret(); err != nil {
			return "", err
		}
		http.SetCookie(h.response, h.makeStateCookie(state, stateCookieMaxAge))
	}

	// TODO: Is there a use case where we need to support direct pass-through of access_type and prompt from the caller?
	var authOpts []oauth2.AuthCodeOption
	if h.getBoolQuery(requestParamOffline) {
		// Set access type to offline and prompt to consent in auth code request URL.
		authOpts = append(authOpts, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
	}

	parsed, err := url.Parse(client.Config().AuthCodeURL(state, authOpts...))
	if err != nil {
		return "", err
	}
	return parsed.String(), nil
}
// handleOIDCCallback handles the redirect back from the OpenID Connect provider
// after the user has authenticated.
//
// Steps, in order:
//  1. Fail with 401 if the "error" query parameter is set, or with 400 if the
//     "code" query parameter is missing.
//  2. Resolve the provider named by the "provider" query parameter (400 on failure).
//  3. Unless the provider has DisableCallbackState set, require the state cookie
//     to be present and equal to the "state" query parameter, then expire the cookie.
//  4. Exchange the authorization code for a token and read its id_token field.
//  5. Call createSessionForTrustedIdToken and write an OIDCTokenResponse as JSON.
//
// The access token, its type and expires_in are only included when the provider
// has IncludeAccessToken set. expires_in is omitted when the token carries no
// expiry (zero Expiry) rather than being reported as a large negative number.
func (h *handler) handleOIDCCallback() error {
	if callbackError := h.getQuery(requestParamError); callbackError != "" {
		errorDescription := h.getQuery(requestParamErrorDesc)
		return base.HTTPErrorf(http.StatusUnauthorized, "oidc callback received an error: %v", errorDescription)
	}

	code := h.getQuery(requestParamCode)
	if code == "" {
		return base.HTTPErrorf(http.StatusBadRequest, "Code must be present on oidc callback")
	}

	providerName := h.getQuery(requestParamProvider)
	provider, err := h.getOIDCProvider(providerName)
	if err != nil || provider == nil {
		return base.HTTPErrorf(http.StatusBadRequest, "Unable to identify provider for callback request")
	}

	// Validate state parameter to prevent cross-site request forgery (CSRF) when callback state is enabled.
	if !provider.DisableCallbackState {
		stateCookie, cookieErr := h.rq.Cookie(stateCookieName)
		if cookieErr == http.ErrNoCookie || stateCookie == nil {
			return ErrNoStateCookie
		}
		if cookieErr != nil {
			base.WarnfCtx(h.ctx(), "Unexpected error attempting to read OIDC state cookie: %v", cookieErr)
			return ErrReadStateCookie
		}

		stateParam := h.rq.URL.Query().Get(requestParamState)
		if stateParam != stateCookie.Value {
			return ErrStateMismatch
		}

		// Delete the state cookie on successful validation.
		http.SetCookie(h.response, h.makeStateCookie("", -1))
	}

	client, err := provider.GetClient(h.ctx(), h.getOIDCCallbackURL)
	if err != nil {
		return fmt.Errorf("OIDC initialization error: %w", err)
	}

	// Converts the authorization code into a token.
	oidcCtx := auth.GetOIDCClientContext(provider.InsecureSkipVerify)
	token, err := client.Config().Exchange(oidcCtx, code)
	if err != nil {
		return base.HTTPErrorf(http.StatusInternalServerError, "Failed to exchange token: %s", err.Error())
	}

	rawIDToken, ok := token.Extra(keyIDToken).(string)
	if !ok {
		return base.HTTPErrorf(http.StatusInternalServerError, "No id_token field in oauth2 token.")
	}

	// The ID token is a credential, so it is tagged as user data for log redaction.
	base.InfofCtx(h.ctx(), base.KeyAuth, "Obtained token from Authorization Server: %v", base.UD(rawIDToken))

	// Create a Sync Gateway session
	username, sessionID, err := h.createSessionForTrustedIdToken(rawIDToken, provider)
	if err != nil {
		return err
	}

	callbackResponse := &OIDCTokenResponse{
		IDToken:      rawIDToken,
		RefreshToken: token.RefreshToken,
		SessionID:    sessionID,
		Username:     username,
	}

	if provider.IncludeAccessToken {
		callbackResponse.AccessToken = token.AccessToken
		callbackResponse.TokenType = token.TokenType
		// A zero Expiry means no expiry was supplied; leave Expires at 0 so that
		// omitempty drops expires_in from the response.
		if !token.Expiry.IsZero() {
			callbackResponse.Expires = int(time.Until(token.Expiry).Seconds())
		}
	}

	h.writeJSON(callbackResponse)
	return nil
}
// handleOIDCRefresh uses the refresh token supplied in the "refresh_token"
// query parameter to obtain a fresh token from the provider named by the
// "provider" query parameter.
//
// It returns 400 when the refresh token is missing or the provider cannot be
// resolved, and 500 when the refresh fails or the new token has no id_token
// field. On success it calls createSessionForTrustedIdToken and writes an
// OIDCTokenResponse as JSON; the response does not include a refresh token.
//
// The access token, its type and expires_in are only included when the provider
// has IncludeAccessToken set. expires_in is omitted when the token carries no
// expiry (zero Expiry) rather than being reported as a large negative number.
func (h *handler) handleOIDCRefresh() error {
	refreshToken := h.getQuery(requestParamRefreshToken)
	if refreshToken == "" {
		return base.NewHTTPError(http.StatusBadRequest, "Refresh token must be present for oidc refresh")
	}

	providerName := h.getQuery(requestParamProvider)
	provider, err := h.getOIDCProvider(providerName)
	if err != nil || provider == nil {
		return base.HTTPErrorf(http.StatusBadRequest, "Unable to identify provider for callback request")
	}

	client, err := provider.GetClient(h.ctx(), h.getOIDCCallbackURL)
	if err != nil {
		return fmt.Errorf("OIDC initialization error: %w", err)
	}

	oidcCtx := auth.GetOIDCClientContext(provider.InsecureSkipVerify)
	token, err := client.Config().TokenSource(oidcCtx, &oauth2.Token{RefreshToken: refreshToken}).Token()
	if err != nil {
		// The underlying error is logged, not returned to the client.
		base.InfofCtx(h.ctx(), base.KeyAuth, "Unsuccessful token refresh: %v", err)
		return base.HTTPErrorf(http.StatusInternalServerError, "Unable to refresh token.")
	}

	rawIDToken, ok := token.Extra(keyIDToken).(string)
	if !ok {
		return base.HTTPErrorf(http.StatusInternalServerError, "No id_token field in oauth2 token.")
	}

	// The ID token is a credential, so it is tagged as user data for log redaction.
	base.InfofCtx(h.ctx(), base.KeyAuth, "Obtained token from Authorization Server: %v", base.UD(rawIDToken))

	username, sessionID, err := h.createSessionForTrustedIdToken(rawIDToken, provider)
	if err != nil {
		return err
	}

	refreshResponse := &OIDCTokenResponse{
		IDToken:   rawIDToken,
		SessionID: sessionID,
		Username:  username,
	}

	if provider.IncludeAccessToken {
		refreshResponse.AccessToken = token.AccessToken
		refreshResponse.TokenType = token.TokenType
		// A zero Expiry means no expiry was supplied; leave Expires at 0 so that
		// omitempty drops expires_in from the response.
		if !token.Expiry.IsZero() {
			refreshResponse.Expires = int(time.Until(token.Expiry).Seconds())
		}
	}

	h.writeJSON(refreshResponse)
	return nil
}
// createSessionForTrustedIdToken authenticates rawIDToken against provider,
// applies the resulting principal updates, and, unless the provider has
// DisableSession set, creates a session that lasts until the token's expiry.
//
// It returns the user's name and the new session ID; sessionID is empty when
// sessions are disabled for the provider. On any error both strings are empty.
// ErrInvalidLogin is returned when authentication succeeds without yielding a
// user.
func (h *handler) createSessionForTrustedIdToken(rawIDToken string, provider *auth.OIDCProvider) (username string, sessionID string, err error) {
	user, updates, tokenExpiryTime, err := h.db.Authenticator(h.ctx()).AuthenticateTrustedJWT(rawIDToken, provider, h.getOIDCCallbackURL)
	if err != nil {
		return "", "", err
	}
	if user == nil {
		return "", "", ErrInvalidLogin
	}

	// NOTE(review): the meaning of the two boolean arguments is not visible from this file - confirm against UpdatePrincipal.
	if _, _, err = h.db.UpdatePrincipal(h.ctx(), &updates, true, true); err != nil {
		return "", "", fmt.Errorf("failed to update user: %w", err)
	}

	if provider.DisableSession {
		return user.Name(), "", nil
	}

	// The session's TTL is the time remaining until the ID token expires.
	sessionID, err = h.makeSessionWithTTL(user, time.Until(tokenExpiryTime))
	if err != nil {
		return "", "", err
	}
	return user.Name(), sessionID, nil
}
// getOIDCProvider looks up providerName on the handler's database. A lookup
// error and a nil provider are reported the same way: a 400 (Bad Request)
// error naming the database; the underlying lookup error is not propagated.
// A non-nil provider is therefore always accompanied by a nil error.
func (h *handler) getOIDCProvider(providerName string) (*auth.OIDCProvider, error) {
	if found, lookupErr := h.db.GetOIDCProvider(providerName); lookupErr == nil && found != nil {
		return found, nil
	}
	return nil, base.HTTPErrorf(
		http.StatusBadRequest, "OpenID Connect not configured for database %v", h.db.Name)
}
// getOIDCCallbackURL builds the OIDC callback URL based on the current request.
// Used during OIDC Client lazy initialization.
//
// The URL has the form <scheme>://<request host>/<db>/_oidc_callback, where the
// scheme is "https" when the request arrived over TLS and "http" otherwise.
// providerName and isDefault describe the requested provider: the provider is
// appended as a query parameter only when it is named and is not the default.
//
// It returns "" when no database name can be derived from the request path. If
// appending the provider fails, a warning is logged and the result of
// auth.SetURLQueryParam is returned as-is.
func (h *handler) getOIDCCallbackURL(providerName string, isDefault bool) string {
	// h.db not initialized at this point (checkPublicAuth) from validateAndWriteHeaders
	// we'll have to pull it out of the router path rather than using h.db.Name
	dbName := h.PathVar("db")
	if dbName == "" {
		// could be a keyspace-scoped request instead; a parse failure is
		// deliberately ignored because the empty-name check below covers it.
		dbName, _, _, _ = ParseKeyspace(h.PathVar("keyspace"))
	}
	if dbName == "" {
		base.WarnfCtx(h.ctx(), "Can't calculate OIDC callback URL without DB in path.")
		return ""
	}

	scheme := "http"
	if h.rq.TLS != nil {
		scheme = "https"
	}

	callbackURL := scheme + "://" + h.rq.Host + "/" + dbName + "/_oidc_callback"
	if isDefault || providerName == "" {
		return callbackURL
	}

	// Keep the base URL in its own variable so the warning reports the URL we
	// tried to modify rather than whatever the failed call returned.
	withProvider, err := auth.SetURLQueryParam(callbackURL, auth.OIDCAuthProvider, providerName)
	if err != nil {
		base.WarnfCtx(h.ctx(), "Failed to add provider %q to OIDC callback URL (%s): %v", base.UD(providerName), callbackURL, err)
	}
	return withProvider
}
// makeStateCookie returns an HttpOnly cookie named stateCookieName carrying the
// given value and Max-Age. The cookie is marked Secure when the current request
// arrived over TLS. Expires is additionally set, to now plus maxAge seconds,
// only when maxAge is positive; where both attributes are present Max-Age takes
// precedence in determining expiration. The cookie is passed through
// base.AddDbPathToCookie before it is returned.
func (h *handler) makeStateCookie(value string, maxAge int) *http.Cookie {
	stateCookie := &http.Cookie{
		Name:     stateCookieName,
		Value:    value,
		MaxAge:   maxAge,
		HttpOnly: true,
		Secure:   h.rq.TLS != nil,
	}

	if maxAge > 0 {
		lifetime := time.Duration(maxAge) * time.Second
		stateCookie.Expires = time.Now().Add(lifetime)
	}

	base.AddDbPathToCookie(h.rq, stateCookie)
	return stateCookie
}