Skip to content

Commit

Permalink
Enable user to configure custom Access/Refresh token (#5087)
Browse files Browse the repository at this point in the history
* addition of access/refresh token configuration
  • Loading branch information
chavakula authored Feb 21, 2025
1 parent d0b5f73 commit ff5db96
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 23 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG-6.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic
Versioning](http://semver.org/spec/v2.0.0.html).
## [6.13.0] - Unreleased

### Added
- Added `access-token-expiry` (in minutes) backend configuration variable to control expiry of access token.
- Added `refresh-token-expiry` (in minutes) backend configuration variable to control expiry of refresh token.

## [6.12.0] - 2024-11-13

Expand Down
35 changes: 30 additions & 5 deletions backend/api/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"errors"
"fmt"

corev2 "github.com/sensu/core/v2"
"time"

"github.com/sensu/sensu-go/backend/authentication"
"github.com/sensu/sensu-go/backend/authentication/jwt"
Expand Down Expand Up @@ -51,8 +51,14 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
claims.Issuer = issuer.(string)
}

// append configured access token expiry to claims
var accessTokenExpiry time.Duration
if accessTokenExp := ctx.Value("accessTokenExpiry"); accessTokenExp != nil {
accessTokenExpiry = accessTokenExp.(time.Duration)
}

// Create an access token and its signed version
_, tokenString, err := jwt.AccessToken(claims)
_, tokenString, err := jwt.AccessToken(claims, jwt.WithAccessTokenExpiry(accessTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating access token: %s", err)
}
Expand All @@ -62,7 +68,14 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
refreshToken, refreshTokenString, err := jwt.RefreshToken(refreshClaims)

// append configured refresh token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

refreshToken, refreshTokenString, err := jwt.RefreshToken(refreshClaims, jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating refresh token: %s", err)
}
Expand Down Expand Up @@ -198,18 +211,30 @@ func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.
claims.Issuer = issuer.(string)
}

// append configured access token expiry to claims
var accessTokenExpiry time.Duration
if accessTokenExp := ctx.Value("accessTokenExpiry"); accessTokenExp != nil {
accessTokenExpiry = accessTokenExp.(time.Duration)
}

// Issue a new access token
_, newAccessTokenString, err := jwt.AccessToken(claims)
_, newAccessTokenString, err := jwt.AccessToken(claims, jwt.WithAccessTokenExpiry(accessTokenExpiry))
if err != nil {
return nil, err
}

// append configured refresh token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

// Create a new refresh token, carrying over the session ID
newRefreshClaims := &corev2.Claims{
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
newRefreshToken, newRefreshTokenString, err := jwt.RefreshToken(newRefreshClaims)
newRefreshToken, newRefreshTokenString, err := jwt.RefreshToken(newRefreshClaims, jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating refresh token: %s", err)
}
Expand Down
14 changes: 13 additions & 1 deletion backend/api/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"

corev2 "github.com/sensu/core/v2"
"github.com/sensu/sensu-go/backend/authentication"
Expand Down Expand Up @@ -33,6 +34,10 @@ func contextWithClaims(claims *corev2.Claims) context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, corev2.AccessTokenClaims, claims)
ctx = context.WithValue(ctx, corev2.RefreshTokenClaims, refreshClaims)

ctx = context.WithValue(ctx, "accessTokenExpiry", 5*time.Minute)
ctx = context.WithValue(ctx, "refreshTokenExpiry", 12*time.Hour)

return ctx
}

Expand Down Expand Up @@ -205,7 +210,14 @@ func TestRefreshAccessToken(t *testing.T) {
Authenticator: defaultAuth,
Context: func(claims *corev2.Claims) (context.Context, string) {
ctx := contextWithClaims(claims)
refreshToken, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims))

// append configured access token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

refreshToken, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims), jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
refreshTokenClaims, _ := jwt.GetClaims(refreshToken)
ctx = context.WithValue(ctx, corev2.RefreshTokenString, refreshTokenString)
return ctx, refreshTokenClaims.Id
Expand Down
9 changes: 8 additions & 1 deletion backend/apid/apid.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type APId struct {

serveWaitTime time.Duration
ready func()

AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}

// Option is a functional option.
Expand All @@ -81,6 +84,8 @@ type Config struct {
ClusterVersion string
GraphQLService *graphql.Service
HealthRouter *routers.HealthRouter
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}

// New creates a new APId.
Expand All @@ -102,6 +107,8 @@ func New(c Config, opts ...Option) (*APId, error) {
clusterVersion: c.ClusterVersion,
RequestLimit: c.RequestLimit,
serveWaitTime: c.ServeWaitTime,
AccessTokenExpiry: c.AccessTokenExpiry,
RefreshTokenExpiry: c.RefreshTokenExpiry,
}

// prepare TLS config
Expand Down Expand Up @@ -174,7 +181,7 @@ func AuthenticationSubrouter(router *mux.Router, cfg Config) *mux.Router {
)

mountRouters(subrouter,
routers.NewAuthenticationRouter(cfg.Store, cfg.Authenticator),
routers.NewAuthenticationRouter(cfg.Store, cfg.Authenticator, cfg.AccessTokenExpiry, cfg.RefreshTokenExpiry),
)

return subrouter
Expand Down
19 changes: 15 additions & 4 deletions backend/apid/routers/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net/http"
"time"

"github.com/sensu/sensu-go/backend/authentication/jwt"

Expand All @@ -17,13 +18,15 @@ import (

// AuthenticationRouter handles authentication related requests
type AuthenticationRouter struct {
store store.Store
authenticator *authentication.Authenticator
store store.Store
authenticator *authentication.Authenticator
accessTokenExpiry time.Duration
refreshTokenExpiry time.Duration
}

// NewAuthenticationRouter instantiates new router.
func NewAuthenticationRouter(store store.Store, authenticator *authentication.Authenticator) *AuthenticationRouter {
return &AuthenticationRouter{store: store, authenticator: authenticator}
func NewAuthenticationRouter(store store.Store, authenticator *authentication.Authenticator, accessTokenExpiry time.Duration, refreshTokenExpiry time.Duration) *AuthenticationRouter {
return &AuthenticationRouter{store: store, authenticator: authenticator, accessTokenExpiry: accessTokenExpiry, refreshTokenExpiry: refreshTokenExpiry}
}

// Mount the authentication routes on given mux.Router.
Expand All @@ -47,6 +50,10 @@ func (a *AuthenticationRouter) login(w http.ResponseWriter, r *http.Request) {
// issuer URL
ctx := context.WithValue(r.Context(), jwt.IssuerURLKey, issuerURL(r))

// Not very efficient, but acceptable for simple use cases, ideally we should create a struct and pass the struct
ctx = context.WithValue(ctx, "accessTokenExpiry", a.accessTokenExpiry)
ctx = context.WithValue(ctx, "refreshTokenExpiry", a.refreshTokenExpiry)

client := api.NewAuthenticationClient(a.authenticator, a.store)
tokens, err := client.CreateAccessToken(ctx, username, password)
if err != nil {
Expand Down Expand Up @@ -106,6 +113,10 @@ func (a *AuthenticationRouter) token(w http.ResponseWriter, r *http.Request) {
// issuer URL
ctx := context.WithValue(r.Context(), jwt.IssuerURLKey, issuerURL(r))

// Not very efficient, but acceptable for simple use cases, ideally we should create a struct and pass the struct
ctx = context.WithValue(ctx, "accessTokenExpiry", a.accessTokenExpiry)
ctx = context.WithValue(ctx, "refreshTokenExpiry", a.refreshTokenExpiry)

tokens, err := client.RefreshAccessToken(ctx)
if err != nil {
if err == corev2.ErrInvalidToken {
Expand Down
56 changes: 52 additions & 4 deletions backend/authentication/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ const (
IssuerURLKey key = iota
)

// ExpiryOptions Functional Options Pattern
// ExpiryOptions: Define a struct for optional parameters.
type ExpiryOptions struct {
RefreshTokenExpiry time.Duration
AccessTokenExpiry time.Duration
}

// ExpiryOption Define a functional option type.
type ExpiryOption func(options *ExpiryOptions)

// WithRefreshTokenExpiry for setting refresh token expiry
func WithRefreshTokenExpiry(expiry time.Duration) ExpiryOption {
return func(o *ExpiryOptions) {
o.RefreshTokenExpiry = expiry
}
}

// WithAccessTokenExpiry for setting access token expiry
func WithAccessTokenExpiry(expiry time.Duration) ExpiryOption {
return func(o *ExpiryOptions) {
o.AccessTokenExpiry = expiry
}
}

var (
DefaultAccessTokenLifespan = 5 * time.Minute
defaultRefreshTokenLifespan = 12 * time.Hour
Expand All @@ -49,16 +73,27 @@ func init() {

// AccessToken creates a new access token and returns it in both JWT and
// signed format, along with any error
func AccessToken(claims *corev2.Claims) (*jwt.Token, string, error) {
func AccessToken(claims *corev2.Claims, options ...ExpiryOption) (*jwt.Token, string, error) {
// Create a unique identifier for the token
jti, err := GenJTI()
if err != nil {
return nil, "", err
}
claims.Id = jti

// Default options.
opts := ExpiryOptions{
RefreshTokenExpiry: defaultRefreshTokenLifespan,
AccessTokenExpiry: DefaultAccessTokenLifespan,
}

// Apply functional options.
for _, option := range options {
option(&opts)
}

// Add an expiration to the token
claims.ExpiresAt = time.Now().Add(DefaultAccessTokenLifespan).Unix()
claims.ExpiresAt = time.Now().Add(opts.AccessTokenExpiry).Unix()

token := jwt.NewWithClaims(signingMethod, claims)

Expand Down Expand Up @@ -246,18 +281,31 @@ func parseToken(tokenString string) (*jwt.Token, error) {
}

// RefreshToken returns a refresh token for a specific user
func RefreshToken(claims *corev2.Claims) (*jwt.Token, string, error) {
func RefreshToken(claims *corev2.Claims, options ...ExpiryOption) (*jwt.Token, string, error) {
// Create a unique identifier for the token
jti, err := GenJTI()
if err != nil {
return nil, "", err
}
claims.Id = jti

// Default options.
opts := ExpiryOptions{
RefreshTokenExpiry: defaultRefreshTokenLifespan,
AccessTokenExpiry: DefaultAccessTokenLifespan,
}

// Apply functional options.
for _, option := range options {
option(&opts)
}

// Add an expiration to the token
claims.ExpiresAt = time.Now().Add(opts.RefreshTokenExpiry).Unix()

// Add issuance and expiration timestamps to the token
now := time.Now()
claims.IssuedAt = now.Unix()
claims.ExpiresAt = now.Add(defaultRefreshTokenLifespan).Unix()

token := jwt.NewWithClaims(signingMethod, claims)

Expand Down
2 changes: 2 additions & 0 deletions backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ func Initialize(ctx context.Context, config *Config) (*Backend, error) {
ClusterVersion: clusterVersion,
GraphQLService: b.GraphQLService,
HealthRouter: b.HealthRouter,
AccessTokenExpiry: config.AccessTokenExpiry,
RefreshTokenExpiry: config.RefreshTokenExpiry,
}
newApi, err := apid.New(b.APIDConfig)
if err != nil {
Expand Down
31 changes: 23 additions & 8 deletions backend/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ const (
flagMaxSilencedExpiryTimeAllowed = "max-silenced-expiry-time-allowed"
flagDefaultSilencedExpiryTime = "default-silenced-expiry-time"

// access token and refresh token expiry time
flagAccessTokenExpiry = "access-token-expiry"
flagRefreshTokenExpiry = "refresh-token-expiry"

// Etcd flag constants
flagEtcdClientURLs = "etcd-client-urls"
flagEtcdListenClientURLs = "etcd-listen-client-urls"
Expand Down Expand Up @@ -293,6 +297,9 @@ func StartCommand(initialize InitializeFunc) *cobra.Command {
EventLogBufferWait: viper.GetDuration(flagEventLogBufferWait),
EventLogFile: viper.GetString(flagEventLogFile),
EventLogParallelEncoders: viper.GetBool(flagEventLogParallelEncoders),

AccessTokenExpiry: viper.GetDuration(flagAccessTokenExpiry),
RefreshTokenExpiry: viper.GetDuration(flagRefreshTokenExpiry),
}

if flag := cmd.Flags().Lookup(flagLabels); flag != nil && flag.Changed {
Expand Down Expand Up @@ -455,12 +462,16 @@ func handleConfig(cmd *cobra.Command, arguments []string, server bool) error {
viper.SetDefault(flagEventLogBufferSize, 100000)
viper.SetDefault(flagEventLogFile, "")
viper.SetDefault(flagEventLogParallelEncoders, false)

// default silenced value are set for 1 day = 1440m
viper.SetDefault(flagMaxSilencedExpiryTimeAllowed, "1440m")
viper.SetDefault(flagDefaultSilencedExpiryTime, "1440m")
}

// default silenced value are set for 1 day = 1440m
viper.SetDefault(flagMaxSilencedExpiryTimeAllowed, "1440m")
viper.SetDefault(flagDefaultSilencedExpiryTime, "1440m")

// Access/Refresh token default expiry values
viper.SetDefault(flagAccessTokenExpiry, "5m")
viper.SetDefault(flagRefreshTokenExpiry, "720m")

// Etcd defaults
viper.SetDefault(flagEtcdAdvertiseClientURLs, defaultEtcdAdvertiseClientURL)
viper.SetDefault(flagEtcdListenClientURLs, defaultEtcdClientURL)
Expand Down Expand Up @@ -552,6 +563,14 @@ func flagSet(server bool) *pflag.FlagSet {
flagSet.String(flagEtcdClientURLs, viper.GetString(flagEtcdClientURLs), "client URLs to use when operating as an etcd client")
_ = flagSet.SetAnnotation(flagEtcdClientURLs, "categories", []string{"store"})

// silenced configuration flags
flagSet.Duration(flagDefaultSilencedExpiryTime, viper.GetDuration(flagDefaultSilencedExpiryTime), "Default expiry time for silenced if not set in minutes")
flagSet.Duration(flagMaxSilencedExpiryTimeAllowed, viper.GetDuration(flagMaxSilencedExpiryTimeAllowed), "Maximum expiry time allowed for silenced in minutes")

// Access/Token configuration flags
flagSet.Duration(flagAccessTokenExpiry, viper.GetDuration(flagAccessTokenExpiry), "Set Access Token expiry in minutes")
flagSet.Duration(flagRefreshTokenExpiry, viper.GetDuration(flagRefreshTokenExpiry), "Set Refresh Token expiry in minutes")

if server {
// Main Flags
flagSet.String(flagAgentHost, viper.GetString(flagAgentHost), "agent listener host")
Expand Down Expand Up @@ -594,10 +613,6 @@ func flagSet(server bool) *pflag.FlagSet {
flagSet.Duration(flagPlatformMetricsLoggingInterval, viper.GetDuration(flagPlatformMetricsLoggingInterval), "platform metrics logging interval")
flagSet.String(flagPlatformMetricsLogFile, viper.GetString(flagPlatformMetricsLogFile), "platform metrics log file path")

// silenced configuration flags
flagSet.Duration(flagDefaultSilencedExpiryTime, viper.GetDuration(flagDefaultSilencedExpiryTime), "Default expiry time for silenced if not set in minutes")
flagSet.Duration(flagMaxSilencedExpiryTimeAllowed, viper.GetDuration(flagMaxSilencedExpiryTimeAllowed), "Maximum expiry time allowed for silenced in minutes")

// Etcd server flags
flagSet.StringSlice(flagEtcdPeerURLs, viper.GetStringSlice(flagEtcdPeerURLs), "list of URLs to listen on for peer traffic")
_ = flagSet.SetAnnotation(flagEtcdPeerURLs, "categories", []string{"store"})
Expand Down
4 changes: 4 additions & 0 deletions backend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,8 @@ type Config struct {
// expiry setting for silences
DefaultSilencedExpiryTime time.Duration
MaxSilencedExpiryTimeAllowed time.Duration

// Access/Refresh Token Expiry in Minutes
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}

0 comments on commit ff5db96

Please sign in to comment.