Skip to content

Commit

Permalink
refactor(appx): optimize JWT token validation with concurrent validators
Browse files Browse the repository at this point in the history
  • Loading branch information
kasugamirai committed Sep 30, 2024
1 parent 94879db commit 59549b7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
49 changes: 35 additions & 14 deletions appx/jwt_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"

"github.com/auth0/go-jwt-middleware/v2/validator"
"github.com/reearth/reearthx/log"
Expand All @@ -29,7 +30,7 @@ func NewJWTValidatorWithError(
audience []string,
opts ...validator.Option,
) (*JWTValidatorWithError, error) {
validator, err := validator.New(
v, err := validator.New(
keyFunc,
signatureAlgorithm,
issuerURL,
Expand All @@ -40,7 +41,7 @@ func NewJWTValidatorWithError(
return nil, err
}
return &JWTValidatorWithError{
validator: validator,
validator: v,
iss: issuerURL,
aud: slices.Clone(audience),
}, nil
Expand All @@ -49,9 +50,9 @@ func NewJWTValidatorWithError(
func (v *JWTValidatorWithError) ValidateToken(ctx context.Context, token string) (interface{}, error) {
res, err := v.validator.ValidateToken(ctx, token)
if err != nil {
err = fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err)
return nil, fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err)
}
return res, err
return res, nil
}

type JWTMultipleValidator []JWTValidator
Expand All @@ -62,20 +63,40 @@ func NewJWTMultipleValidator(providers []JWTProvider) (JWTMultipleValidator, err
})
}

// ValidateToken Trys to validate the token with each validator
// ValidateToken tries to validate the token with each validator concurrently
// NOTE: the last validation error only is returned
func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) {
func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (interface{}, error) {
type result struct {
res interface{}
err error
}

resultChan := make(chan result, len(mv))
var wg sync.WaitGroup

for _, v := range mv {
var err2 error
res, err2 = v.ValidateToken(ctx, tokenString)
if err2 == nil {
err = nil
return
wg.Add(1)
go func(validator JWTValidator) {
defer wg.Done()
res, err := validator.ValidateToken(ctx, tokenString)
resultChan <- result{res, err}
}(v)
}

go func() {
wg.Wait()
close(resultChan)
}()

var lastErr error
for r := range resultChan {
if r.err == nil {
return r.res, nil
}
err = errors.Join(err, err2)
lastErr = errors.Join(lastErr, r.err)
}

log.Debugfc(ctx, "auth: invalid JWT token: %s", tokenString)
log.Errorfc(ctx, "auth: invalid JWT token: %v", err)
return
log.Errorfc(ctx, "auth: invalid JWT token: %v", lastErr)
return nil, lastErr
}
6 changes: 3 additions & 3 deletions appx/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type TracerConfig struct {
TracerSample float64
}

func InitTracer(ctx context.Context, conf TracerConfig) io.Closer {
func InitTracer(ctx context.Context, conf *TracerConfig) io.Closer {
if conf.Tracer == TRACER_GCP {
initGCPTracer(ctx, conf)
} else if conf.Tracer == TRACER_JAEGER {
Expand All @@ -34,7 +34,7 @@ func InitTracer(ctx context.Context, conf TracerConfig) io.Closer {
return nil
}

func initGCPTracer(ctx context.Context, conf TracerConfig) {
func initGCPTracer(ctx context.Context, conf *TracerConfig) {
exporter, err := texporter.New()
if err != nil {
log.Fatalc(ctx, err)
Expand All @@ -50,7 +50,7 @@ func initGCPTracer(ctx context.Context, conf TracerConfig) {
log.Infofc(ctx, "tracer: initialized cloud trace with sample fraction: %g", conf.TracerSample)
}

func initJaegerTracer(conf TracerConfig) io.Closer {
func initJaegerTracer(conf *TracerConfig) io.Closer {
cfg := jaegercfg.Configuration{
Sampler: &jaegercfg.SamplerConfig{
Type: jaeger.SamplerTypeConst,
Expand Down

0 comments on commit 59549b7

Please sign in to comment.