From 373ec8932b6561dffc7ad567df44f3e4b8d6e29a Mon Sep 17 00:00:00 2001 From: xy Date: Thu, 3 Oct 2024 13:47:30 +0900 Subject: [PATCH] Resolve data race in JWT validator test --- appx/jwt_validator_test.go | 44 +++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/appx/jwt_validator_test.go b/appx/jwt_validator_test.go index 27e2284..bdc4003 100644 --- a/appx/jwt_validator_test.go +++ b/appx/jwt_validator_test.go @@ -6,6 +6,7 @@ import ( "crypto/rsa" "encoding/json" "net/http" + "sync" "testing" "time" @@ -23,7 +24,10 @@ func TestMultiValidator(t *testing.T) { key := lo.Must(rsa.GenerateKey(rand.Reader, 2048)) httpmock.Activate() - defer httpmock.DeactivateAndReset() + t.Cleanup(func() { + httpmock.DeactivateAndReset() + }) + httpmock.RegisterResponder( http.MethodGet, "https://example.com/.well-known/openid-configuration", @@ -207,4 +211,42 @@ func TestMultiValidator(t *testing.T) { assert.Nil(t, res) assert.ErrorIs(t, err, context.Canceled) }) + + t.Run("mixed valid and invalid tokens", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + // Test with valid token + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + + // Test with invalid token + res, err = v.ValidateToken(context.Background(), "invalid.token") + assert.Error(t, err) + assert.Nil(t, res) + }) + + t.Run("concurrent validations", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + }() + } + wg.Wait() + }) }