Skip to content

Commit

Permalink
Adding refreshNow functionality to the cache (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
wlewis4321 authored Apr 30, 2024
1 parent 0a82475 commit acd923a
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 1 deletion.
10 changes: 10 additions & 0 deletions secretcache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,13 @@ func (c *Cache) GetSecretBinaryWithStageWithContext(ctx context.Context, secretI

return getSecretValueOutput.SecretBinary, nil
}

// Method to force the refresh of a secret inside the cache
func (c *Cache) RefreshNow(secretId string) {
c.RefreshNowWithContext(aws.BackgroundContext(), secretId)
}

func (c *Cache) RefreshNowWithContext(ctx context.Context, secretId string) {
secretCacheItem := c.getCachedSecret(secretId)
secretCacheItem.refreshNow(ctx)
}
17 changes: 17 additions & 0 deletions secretcache/cacheItem.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,23 @@ func (ci *secretCacheItem) getVersion(versionStage string) (*cacheVersion, bool)
return secretCacheVersion, true
}

// refresh the cached object on demand
func (ci *secretCacheItem) refreshNow(ctx context.Context) {
ci.refreshNeeded = true
// Generate a random number to have a sleep jitter to not get stuck in a retry loop
sleep := rand.Int63n((forceRefreshJitterSleep+1)-(forceRefreshJitterSleep/2)+1) + (forceRefreshJitterSleep / 2)

if ci.err != nil {
exceptionSleep := ci.nextRefreshTime - time.Now().UnixNano()
if exceptionSleep > sleep {
sleep = exceptionSleep
}
}

time.Sleep(time.Millisecond * time.Duration(sleep))
ci.refresh(ctx)
}

// refresh the cached object when needed.
func (ci *secretCacheItem) refresh(ctx context.Context) {
if !ci.isRefreshNeeded() {
Expand Down
1 change: 1 addition & 0 deletions secretcache/cacheObject.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
exceptionRetryDelayBase = 1
exceptionRetryGrowthFactor = 2
exceptionRetryDelayMax = 3600
forceRefreshJitterSleep = 5000
)

// Base cache object for common properties.
Expand Down
38 changes: 38 additions & 0 deletions secretcache/cacheObjects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,44 @@ func TestMaxCacheTTL(t *testing.T) {
}
}

func TestRefreshNow(t *testing.T) {
mockClient := dummyClient{}

cacheItem := secretCacheItem{
cacheObject: &cacheObject{
secretId: "dummy-secret-name",
client: &mockClient,
data: &secretsmanager.DescribeSecretOutput{
ARN: getStrPtr("dummy-arn"),
Name: getStrPtr("dummy-name"),
Description: getStrPtr("dummy-description"),
},
},
}

config := CacheConfig{CacheItemTTL: 0}
cacheItem.config = config
cacheItem.refresh(aws.BackgroundContext())
refreshTime := cacheItem.nextRefreshTime

cacheItem.refresh(aws.BackgroundContext())

if refreshTime != cacheItem.nextRefreshTime {
t.Fatalf("Expected nextRefreshTime to be same")
}

cacheItem.refreshNow(aws.BackgroundContext())

if cacheItem.nextRefreshTime == refreshTime {
t.Fatalf("Expected nextRefreshTime to be different")
}

if cacheItem.errorCount > 0 {
t.Fatalf("Expected errorCount to be 0")
}

}

type dummyClient struct {
secretsmanageriface.SecretsManagerAPI
}
Expand Down
47 changes: 46 additions & 1 deletion secretcache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package secretcache_test
import (
"bytes"
"errors"
"github.com/aws/aws-secretsmanager-caching-go/secretcache"
"testing"
"time"

"github.com/aws/aws-secretsmanager-caching-go/secretcache"

"github.com/aws/aws-sdk-go/service/secretsmanager"
)
Expand Down Expand Up @@ -356,6 +358,49 @@ func TestGetSecretBinaryMultipleNotFound(t *testing.T) {
}
}

func TestRefreshNow(t *testing.T) {
mockClient, secretId, secretString := newMockedClientWithDummyResults()
secretCache, _ := secretcache.New(
func(c *secretcache.Cache) { c.Client = &mockClient },
func(c *secretcache.Cache) { c.CacheConfig.CacheItemTTL = time.Hour.Nanoseconds() },
)
originalSecret, err := secretCache.GetSecretString(secretId)
if err != nil {
t.Fatalf("Unexpected error - %s", err.Error())
}
if originalSecret != secretString {
t.Fatalf("Expected and result secret string are different - \"%s\", \"%s\"", secretString, originalSecret)
}

_, _ = secretCache.GetSecretString(secretId)

if mockClient.DescribeSecretCallCount != 1 {
t.Fatalf("Expected a single call to DescribeSecret API, got %d", mockClient.DescribeSecretCallCount)
}

secretCache.RefreshNow(secretId)
refreshedSecret, err := secretCache.GetSecretString(secretId)

if err != nil {
t.Fatalf("Unexpected error - %s", err.Error())
}

if refreshedSecret != secretString {
t.Fatalf("Expected and result secret string are different - \"%s\", \"%s\"", secretString, refreshedSecret)
}

if mockClient.DescribeSecretCallCount != 2 {
t.Fatalf("Expected two calls to DescribeSecret API, got %d", mockClient.DescribeSecretCallCount)
}

_, _ = secretCache.GetSecretString(secretId)

if mockClient.DescribeSecretCallCount != 2 {
t.Fatalf("Expected two calls to DescribeSecret API, got %d", mockClient.DescribeSecretCallCount)
}

}

func TestGetSecretVersionStageEmpty(t *testing.T) {
mockClient, _, secretString := newMockedClientWithDummyResults()

Expand Down

0 comments on commit acd923a

Please sign in to comment.