Skip to content

Commit

Permalink
Add tests for change
Browse files Browse the repository at this point in the history
  • Loading branch information
kasugamirai committed Oct 1, 2024
1 parent 1c38c07 commit 583aa90
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 0 deletions.
86 changes: 86 additions & 0 deletions appx/jwt_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,90 @@ func TestMultiValidator(t *testing.T) {
res3, err := v.ValidateToken(context.Background(), tokenString3)
assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer)
assert.Nil(t, res3)

t.Run("all validators fail", func(t *testing.T) {
invalidTokenString := "invalid.token.string"

res, err := v.ValidateToken(context.Background(), invalidTokenString)
assert.Error(t, err)
assert.Nil(t, res)

// Check if the error is a combination of multiple errors
var multiErr interface{ Unwrap() []error }
assert.ErrorAs(t, err, &multiErr)
errs := multiErr.Unwrap()
assert.Len(t, errs, 2)

// Check if both errors are related to invalid token
for _, e := range errs {
assert.Contains(t, e.Error(), "invalid JWT")
}
})

t.Run("first validator succeeds", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)
claims, ok := res.(*validator.ValidatedClaims)
assert.True(t, ok)
assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer)
})

t.Run("second validator succeeds", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.NoError(t, err)
assert.NotNil(t, res)
claims, ok := res.(*validator.ValidatedClaims)
assert.True(t, ok)
assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer)
})

t.Run("all validators fail", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example3.com/", AUD: []string{"d"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

res, err := v.ValidateToken(context.Background(), tokenString)
assert.Error(t, err)
assert.Nil(t, res)

var multiErr interface{ Unwrap() []error }
assert.ErrorAs(t, err, &multiErr)
errs := multiErr.Unwrap()
assert.Len(t, errs, 2)

for _, e := range errs {
assert.Contains(t, e.Error(), "invalid JWT")
}
})

t.Run("context cancellation", func(t *testing.T) {
v, err := NewJWTMultipleValidator([]JWTProvider{
{ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name},
{ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name},
})
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
cancel()

res, err := v.ValidateToken(ctx, tokenString)
assert.Error(t, err)
assert.Nil(t, res)
assert.ErrorIs(t, err, context.Canceled)
})
}
148 changes: 148 additions & 0 deletions appx/tracer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package appx

import (
"context"
"io"
"strings"
"testing"

"github.com/reearth/reearthx/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.opentelemetry.io/otel"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// Mock for GCP exporter
type mockGCPExporter struct {
mock.Mock
}

func (m *mockGCPExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error {
args := m.Called(ctx, spans)
return args.Error(0)
}

func (m *mockGCPExporter) Shutdown(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}

// Mock for Jaeger closer
type mockCloser struct {
mock.Mock
}

func (m *mockCloser) Close() error {
args := m.Called()
return args.Error(0)
}

type testLogWriter struct {
strings.Builder
}

func (w *testLogWriter) Write(p []byte) (int, error) {
return w.Builder.Write(p)
}

func TestInitTracer(t *testing.T) {
// Create function variables
var testInitGCPTracer func(ctx context.Context, conf *TracerConfig)
var testInitJaegerTracer func(conf *TracerConfig) io.Closer

// Create a test wrapper for InitTracer that uses the function variables
testInitTracer := func(ctx context.Context, conf *TracerConfig) io.Closer {
if conf.Tracer == TRACER_GCP {
testInitGCPTracer(ctx, conf)
return nil
} else if conf.Tracer == TRACER_JAEGER {
return testInitJaegerTracer(conf)
}
return nil
}

tests := []struct {
name string
config *TracerConfig
setup func()
expected io.Closer
}{
{
name: "GCP Tracer",
config: &TracerConfig{
Name: "test-gcp",
Tracer: TRACER_GCP,
TracerSample: 0.5,
},
setup: func() {
testInitGCPTracer = func(ctx context.Context, conf *TracerConfig) {
// Mock the GCP tracer initialization
mockExporter := &mockGCPExporter{}
tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(mockExporter))
otel.SetTracerProvider(tp)
log.Infofc(ctx, "tracer: initialized cloud trace with sample fraction: %g", conf.TracerSample)
}
},
expected: nil,
},
{
name: "Jaeger Tracer",
config: &TracerConfig{
Name: "test-jaeger",
Tracer: TRACER_JAEGER,
TracerSample: 0.5,
},
setup: func() {
testInitJaegerTracer = func(conf *TracerConfig) io.Closer {
// Mock the Jaeger tracer initialization
mockCloser := &mockCloser{}
mockCloser.On("Close").Return(nil)
log.Infof("tracer: initialized jaeger tracer with sample fraction: %g", conf.TracerSample)
return mockCloser
}
},
expected: &mockCloser{},
},
{
name: "Unknown Tracer",
config: &TracerConfig{
Name: "test-unknown",
Tracer: "unknown",
TracerSample: 0.5,
},
setup: func() {},
expected: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()

// Capture log output
logWriter := &testLogWriter{}
log.SetOutput(logWriter)
defer log.SetOutput(nil)

ctx := context.Background()
closer := testInitTracer(ctx, tt.config)

if tt.expected == nil {
assert.Nil(t, closer)
} else {
assert.NotNil(t, closer)
assert.IsType(t, tt.expected, closer)
}

// Check if the log output contains the expected message
logOutput := logWriter.String()
expectedLogMessage := "tracer: initialized"
if tt.config.Tracer != "unknown" {
assert.Contains(t, logOutput, expectedLogMessage)
} else {
assert.Empty(t, logOutput)
}
})
}
}

0 comments on commit 583aa90

Please sign in to comment.