Skip to content

Commit

Permalink
Merge pull request #2 from brexhq/context-support
Browse files Browse the repository at this point in the history
Update Cache.Get* methods to accept a context
  • Loading branch information
tyen-brex authored Jun 29, 2022
2 parents 36dcefd + 2766c98 commit 9af7101
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 59 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ package main

import (
"github.com/aws/aws-lambda-go/lambda"
"github.com/aws/aws-secretsmanager-caching-go/secretcache"
"github.com/aws/aws-secretsmanager-caching-go/v2/secretcache"
)

var(
secretCache, _ = secretcache.New()
)

func HandleRequest(secretId string) string {
result, _ := secretCache.GetSecretString(secretId)
func HandleRequest(ctx context.Context, secretId string) string {
result, _ := secretCache.GetSecretString(ctx, secretId)
// Use secret to connect to secured resource.
return "Success"
}
Expand Down Expand Up @@ -85,4 +85,4 @@ We use GitHub issues for tracking bugs and caching library feature requests and

## License

This library is licensed under the Apache 2.0 License.
This library is licensed under the Apache 2.0 License.
20 changes: 10 additions & 10 deletions scintegtests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func integTest_getSecretBinary(t *testing.T, api *secretsmanager.Client) string
return ""
}

resultBinary, err := cache.GetSecretBinary(*createResult.ARN)
resultBinary, err := cache.GetSecretBinary(context.Background(), *createResult.ARN)

if err != nil {
t.Error(err)
Expand Down Expand Up @@ -200,7 +200,7 @@ func integTest_getSecretBinaryWithStage(t *testing.T, api *secretsmanager.Client
return *createResult.ARN
}

resultBinary, err := cache.GetSecretBinaryWithStage(*createResult.ARN, "AWSPREVIOUS")
resultBinary, err := cache.GetSecretBinaryWithStage(context.Background(), *createResult.ARN, "AWSPREVIOUS")

if err != nil {
t.Error(err)
Expand All @@ -211,7 +211,7 @@ func integTest_getSecretBinaryWithStage(t *testing.T, api *secretsmanager.Client
t.Error("Expected and result binary not the same")
}

resultBinary, err = cache.GetSecretBinaryWithStage(*createResult.ARN, "AWSCURRENT")
resultBinary, err = cache.GetSecretBinaryWithStage(context.Background(), *createResult.ARN, "AWSCURRENT")

if err != nil {
t.Error(err)
Expand All @@ -237,7 +237,7 @@ func integTest_getSecretString(t *testing.T, api *secretsmanager.Client) string
return ""
}

resultString, err := cache.GetSecretString(*createResult.ARN)
resultString, err := cache.GetSecretString(context.Background(), *createResult.ARN)

if err != nil {
t.Error(err)
Expand Down Expand Up @@ -277,7 +277,7 @@ func integTest_getSecretStringWithStage(t *testing.T, api *secretsmanager.Client
return *createResult.ARN
}

resultString, err := cache.GetSecretStringWithStage(*createResult.ARN, "AWSPREVIOUS")
resultString, err := cache.GetSecretStringWithStage(context.Background(), *createResult.ARN, "AWSPREVIOUS")

if err != nil {
t.Error(err)
Expand All @@ -288,7 +288,7 @@ func integTest_getSecretStringWithStage(t *testing.T, api *secretsmanager.Client
t.Errorf("Expected and result secret string are different - \"%s\", \"%s\"", secretString, resultString)
}

resultString, err = cache.GetSecretStringWithStage(*createResult.ARN, "AWSCURRENT")
resultString, err = cache.GetSecretStringWithStage(context.Background(), *createResult.ARN, "AWSCURRENT")

if err != nil {
t.Error(err)
Expand Down Expand Up @@ -317,7 +317,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client)
return ""
}

resultString, err := cache.GetSecretString(*createResult.ARN)
resultString, err := cache.GetSecretString(context.Background(), *createResult.ARN)

if err != nil {
t.Error(err)
Expand All @@ -342,7 +342,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client)
return *createResult.ARN
}

resultString, err = cache.GetSecretString(*createResult.ARN)
resultString, err = cache.GetSecretString(context.Background(), *createResult.ARN)

if err != nil {
t.Error(err)
Expand All @@ -356,7 +356,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client)

time.Sleep(time.Nanosecond * time.Duration(ttlNanoSeconds))

resultString, err = cache.GetSecretString(*createResult.ARN)
resultString, err = cache.GetSecretString(context.Background(), *createResult.ARN)
if updatedSecretString != resultString {
t.Errorf("Expected cached secret to be same as updated version - \"%s\", \"%s\"", resultString, updatedSecretString)
return *createResult.ARN
Expand All @@ -371,7 +371,7 @@ func integTest_getSecretStringNoSecret(t *testing.T, api *secretsmanager.Client)
)

secretName := "NoSuchSecret"
_, err := cache.GetSecretString(secretName)
_, err := cache.GetSecretString(context.Background(), secretName)

var rnfe *types.ResourceNotFoundException

Expand Down
16 changes: 8 additions & 8 deletions secretcache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ func (c *Cache) getCachedSecret(secretId string) *secretCacheItem {

// GetSecretString gets the secret string value from the cache for given secret id and a default version stage.
// Returns the secret sting and an error if operation failed.
func (c *Cache) GetSecretString(secretId string) (string, error) {
return c.GetSecretStringWithStage(secretId, DefaultVersionStage)
func (c *Cache) GetSecretString(ctx context.Context, secretId string) (string, error) {
return c.GetSecretStringWithStage(ctx, secretId, DefaultVersionStage)
}

// GetSecretStringWithStage gets the secret string value from the cache for given secret id and version stage.
// Returns the secret sting and an error if operation failed.
func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) (string, error) {
func (c *Cache) GetSecretStringWithStage(ctx context.Context, secretId string, versionStage string) (string, error) {
secretCacheItem := c.getCachedSecret(secretId)

getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage)
getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage)

if err != nil {
return "", err
Expand All @@ -116,16 +116,16 @@ func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) (

// GetSecretBinary gets the secret binary value from the cache for given secret id and a default version stage.
// Returns the secret binary and an error if operation failed.
func (c *Cache) GetSecretBinary(secretId string) ([]byte, error) {
return c.GetSecretBinaryWithStage(secretId, DefaultVersionStage)
func (c *Cache) GetSecretBinary(ctx context.Context, secretId string) ([]byte, error) {
return c.GetSecretBinaryWithStage(ctx, secretId, DefaultVersionStage)
}

// GetSecretBinaryWithStage gets the secret binary value from the cache for given secret id and version stage.
// Returns the secret binary and an error if operation failed.
func (c *Cache) GetSecretBinaryWithStage(secretId string, versionStage string) ([]byte, error) {
func (c *Cache) GetSecretBinaryWithStage(ctx context.Context, secretId string, versionStage string) ([]byte, error) {
secretCacheItem := c.getCachedSecret(secretId)

getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage)
getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage)

if err != nil {
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions secretcache/cacheHook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package secretcache_test

import (
"bytes"
"context"
"testing"

"github.com/aws/aws-secretsmanager-caching-go/v2/secretcache"
Expand Down Expand Up @@ -44,7 +45,7 @@ func TestCacheHookString(t *testing.T) {
func(c *secretcache.Cache) { c.CacheConfig.Hook = hook },
)

result, err := secretCache.GetSecretString(secretId)
result, err := secretCache.GetSecretString(context.Background(), secretId)

if err != nil {
t.Fatalf("Unexpected error - %s", err.Error())
Expand Down Expand Up @@ -75,7 +76,7 @@ func TestCacheHookBinary(t *testing.T) {
func(c *secretcache.Cache) { c.CacheConfig.Hook = hook },
)

result, err := secretCache.GetSecretBinary(secretId)
result, err := secretCache.GetSecretBinary(context.Background(), secretId)

if err != nil {
t.Fatalf("Unexpected error - %s", err.Error())
Expand Down
14 changes: 7 additions & 7 deletions secretcache/cacheItem.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ func (ci *secretCacheItem) getVersionId(versionStage string) (string, bool) {

// executeRefresh performs the actual refresh of the cached secret information.
// Returns the DescribeSecret API result and an error if call failed.
func (ci *secretCacheItem) executeRefresh() (*secretsmanager.DescribeSecretOutput, error) {
func (ci *secretCacheItem) executeRefresh(ctx context.Context) (*secretsmanager.DescribeSecretOutput, error) {
input := &secretsmanager.DescribeSecretInput{
SecretId: &ci.secretId,
}

result, err := ci.client.DescribeSecret(context.Background(), input, addUserAgent)
result, err := ci.client.DescribeSecret(ctx, input, addUserAgent)

var maxTTL int64
if ci.config.CacheItemTTL == 0 {
Expand Down Expand Up @@ -130,14 +130,14 @@ func (ci *secretCacheItem) getVersion(versionStage string) (*cacheVersion, bool)
}

// refresh the cached object when needed.
func (ci *secretCacheItem) refresh() {
func (ci *secretCacheItem) refresh(ctx context.Context) {
if !ci.isRefreshNeeded() {
return
}

ci.refreshNeeded = false

result, err := ci.executeRefresh()
result, err := ci.executeRefresh(ctx)

if err != nil {
ci.errorCount++
Expand All @@ -156,7 +156,7 @@ func (ci *secretCacheItem) refresh() {

// getSecretValue gets the cached secret value for the given version stage.
// Returns the GetSecretValue API result and an error if operation fails.
func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.GetSecretValueOutput, error) {
func (ci *secretCacheItem) getSecretValue(ctx context.Context, versionStage string) (*secretsmanager.GetSecretValueOutput, error) {
if versionStage == "" && ci.config.VersionStage == "" {
versionStage = DefaultVersionStage
} else if versionStage == "" && ci.config.VersionStage != "" {
Expand All @@ -166,7 +166,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.
ci.mux.Lock()
defer ci.mux.Unlock()

ci.refresh()
ci.refresh(ctx)
version, ok := ci.getVersion(versionStage)

if !ok {
Expand All @@ -181,7 +181,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.
}

}
return version.getSecretValue()
return version.getSecretValue(ctx)
}

// setWithHook sets the cache item's data using the CacheHook, if one is configured.
Expand Down
4 changes: 2 additions & 2 deletions secretcache/cacheObjects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestMaxCacheTTL(t *testing.T) {
config := CacheConfig{CacheItemTTL: -1}
cacheItem.config = config

_, err := cacheItem.executeRefresh()
_, err := cacheItem.executeRefresh(context.Background())

if err == nil {
t.Fatalf("Expected error due to negative cache ttl")
Expand All @@ -81,7 +81,7 @@ func TestMaxCacheTTL(t *testing.T) {
config = CacheConfig{CacheItemTTL: 0}
cacheItem.config = config

_, err = cacheItem.executeRefresh()
_, err = cacheItem.executeRefresh(context.Background())

if err != nil {
t.Fatalf("Unexpected error on zero cache ttl")
Expand Down
12 changes: 6 additions & 6 deletions secretcache/cacheVersion.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ func (cv *cacheVersion) isRefreshNeeded() bool {
}

// refresh the cached object when needed.
func (cv *cacheVersion) refresh() {
func (cv *cacheVersion) refresh(ctx context.Context) {
if !cv.isRefreshNeeded() {
return
}

cv.refreshNeeded = false

result, err := cv.executeRefresh()
result, err := cv.executeRefresh(ctx)

if err != nil {
cv.errorCount++
Expand All @@ -70,21 +70,21 @@ func (cv *cacheVersion) refresh() {

// executeRefresh performs the actual refresh of the cached secret information.
// Returns the GetSecretValue API result and an error if operation fails.
func (cv *cacheVersion) executeRefresh() (*secretsmanager.GetSecretValueOutput, error) {
func (cv *cacheVersion) executeRefresh(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) {
input := &secretsmanager.GetSecretValueInput{
SecretId: &cv.secretId,
VersionId: &cv.versionId,
}
return cv.client.GetSecretValue(context.Background(), input, addUserAgent)
return cv.client.GetSecretValue(ctx, input, addUserAgent)
}

// getSecretValue gets the cached secret version value.
// Returns the GetSecretValue API cached result and an error if operation fails.
func (cv *cacheVersion) getSecretValue() (*secretsmanager.GetSecretValueOutput, error) {
func (cv *cacheVersion) getSecretValue(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) {
cv.mux.Lock()
defer cv.mux.Unlock()

cv.refresh()
cv.refresh(ctx)

return cv.getWithHook(), cv.err
}
Expand Down
Loading

0 comments on commit 9af7101

Please sign in to comment.