diff --git a/drpcmetadata/metadata.go b/drpcmetadata/metadata.go index c1e2276..f70abac 100644 --- a/drpcmetadata/metadata.go +++ b/drpcmetadata/metadata.go @@ -5,11 +5,12 @@ package drpcmetadata import ( "context" + "strings" "github.com/zeebo/errs" ) -// AddPairs attaches metadata onto a context and return the context. +// AddPairs attaches metadata onto an incoming context and returns the context. func AddPairs(ctx context.Context, metadata map[string]string) context.Context { // Get returns a copy of metadata newMetadata, ok := Get(ctx) @@ -19,7 +20,47 @@ func AddPairs(ctx context.Context, metadata map[string]string) context.Context { for k, v := range metadata { newMetadata[k] = v } - return context.WithValue(ctx, metadataKey{}, newMetadata) + return context.WithValue(ctx, incomingMetadataKey{}, newMetadata) +} + +// AddPairsToOutgoingContext attaches metadata onto an outgoing context and +// returns the context. +func AddPairsToOutgoingContext(ctx context.Context, metadata map[string]string) context.Context { + // Get existing metadata + existingMd, ok := ctx.Value(outgoingMetadataKey{}).(map[string]string) + if !ok { + return ctx + } + newMetadata := make(map[string]string) + for k, v := range existingMd { + newMetadata[k] = v + } + for k, v := range metadata { + newMetadata[k] = v + } + return context.WithValue(ctx, outgoingMetadataKey{}, newMetadata) +} + +// NewIncomingContext attaches new metadata onto a context and returns the +// context. +func NewIncomingContext(ctx context.Context, + metadata map[string]string) context.Context { + newMetadata := make(map[string]string) + for k, v := range metadata { + newMetadata[k] = v + } + return context.WithValue(ctx, incomingMetadataKey{}, newMetadata) +} + +// NewOutgoingContext attaches new metadata onto a context and returns the +// context. +func NewOutgoingContext(ctx context.Context, + metadata map[string]string) context.Context { + newMetadata := make(map[string]string) + for k, v := range metadata { + newMetadata[k] = v + } + return context.WithValue(ctx, outgoingMetadataKey{}, newMetadata) } // Encode generates byte form of the metadata and appends it onto the passed in buffer. @@ -53,31 +94,29 @@ func Decode(buf []byte) (map[string]string, error) { return out, nil } -type metadataKey struct{} +type incomingMetadataKey struct{} +type outgoingMetadataKey struct{} -// ClearContext removes all metadata from the context and returns a new context -// with no metadata attached. +// ClearContext removes all metadata from the incoming context and returns a new +// context with no metadata attached. func ClearContext(ctx context.Context) context.Context { - return context.WithValue(ctx, metadataKey{}, nil) + return context.WithValue(ctx, incomingMetadataKey{}, nil) } -// ClearContextExcept removes all metadata from the context except for the -// specified key. If the specified key doesn't exist in the metadata, it clears +// ClearContextExcept removes all metadata from the incoming context except for +// the specified key. If the specified key doesn't exist in the metadata, it clears // all metadata. Returns a new context with only the specified key-value pair // preserved. func ClearContextExcept(ctx context.Context, key string) context.Context { - md, ok := Get(ctx) - if !ok { - return ClearContext(ctx) - } - value, ok := md[key] + value, ok := GetValue(ctx, key) if !ok { return ClearContext(ctx) } - return context.WithValue(ctx, metadataKey{}, map[string]string{key: value}) + return context.WithValue(ctx, incomingMetadataKey{}, + map[string]string{strings.ToLower(key): value}) } -// Add associates a key/value pair on the context. +// Add associates a key/value pair on the incoming context. func Add(ctx context.Context, key, value string) context.Context { // Get returns a copy of metadata metadata, ok := Get(ctx) @@ -85,12 +124,12 @@ func Add(ctx context.Context, key, value string) context.Context { metadata = make(map[string]string) } metadata[key] = value - return context.WithValue(ctx, metadataKey{}, metadata) + return context.WithValue(ctx, incomingMetadataKey{}, metadata) } -// Get returns all key/value pairs on the given context. +// Get returns all key/value pairs on the given incoming context. func Get(ctx context.Context) (map[string]string, bool) { - metadata, ok := ctx.Value(metadataKey{}).(map[string]string) + metadata, ok := ctx.Value(incomingMetadataKey{}).(map[string]string) if !ok { return nil, false } @@ -102,9 +141,25 @@ func Get(ctx context.Context) (map[string]string, bool) { return copy, true } -// GetValue retrieves a specific value by key from the context's metadata. +// GetFromOutgoingContext returns all key/value pairs on the given incoming +// context. +func GetFromOutgoingContext(ctx context.Context) (map[string]string, bool) { + metadata, ok := ctx.Value(outgoingMetadataKey{}).(map[string]string) + if !ok { + return nil, false + } + // Return a copy to prevent mutation of the original map + copy := make(map[string]string) + for k, v := range metadata { + copy[k] = v + } + return copy, true +} + +// GetValue retrieves a specific value by key from the incoming context's +// metadata. func GetValue(ctx context.Context, key string) (string, bool) { - metadata, ok := Get(ctx) + metadata, ok := ctx.Value(incomingMetadataKey{}).(map[string]string) if !ok { return "", false } diff --git a/drpcmetadata/metadata_test.go b/drpcmetadata/metadata_test.go index 2c713fc..fa158e9 100644 --- a/drpcmetadata/metadata_test.go +++ b/drpcmetadata/metadata_test.go @@ -45,6 +45,24 @@ func TestAddGet(t *testing.T) { } } +func TestGetFromOutgoingContext(t *testing.T) { + ctx := context.Background() + + ctx = context.WithValue(ctx, outgoingMetadataKey{}, map[string]string{ + "foo": "bar", + "ak": "av", + "bk": "bv", + }) + + metadata, ok := GetFromOutgoingContext(ctx) + assert.That(t, ok) + assert.Equal(t, metadata, map[string]string{ + "foo": "bar", + "ak": "av", + "bk": "bv", + }) +} + func TestEncode(t *testing.T) { t.Run("Empty Metadata", func(t *testing.T) { var metadata map[string]string @@ -136,3 +154,101 @@ func TestAddPairsImmutability(t *testing.T) { assert.Equal(t, newMd["key1"], "val1") assert.Equal(t, newMd["key2"], "val2") } + +func TestNewIncomingContext(t *testing.T) { + originalCtx := context.Background() + originalCtx = Add(originalCtx, "existing", "value") + + newCtx := NewIncomingContext(originalCtx, map[string]string{ + "key1": "value1", + "key2": "value2", + }) + originalMd, ok := Get(originalCtx) + assert.That(t, ok) + assert.Equal(t, originalMd, map[string]string{"existing": "value"}) + + newMd, ok := Get(newCtx) + assert.That(t, ok) + assert.Equal(t, newMd, map[string]string{ + "key1": "value1", + "key2": "value2", + }) +} + +func TestClearContext(t *testing.T) { + ctx := context.Background() + ctx = Add(ctx, "existing", "value") + + ctx = ClearContext(ctx) + newMd, ok := Get(ctx) + assert.False(t, ok) + assert.Equal(t, newMd, map[string]string(nil)) +} + +func TestClearContextExcept(t *testing.T) { + ctx := context.Background() + ctx = AddPairs(ctx, map[string]string{ + "key1": "value1", "key2": "value2", + }) + + ctx = ClearContextExcept(ctx, "key1") + md, ok := Get(ctx) + assert.That(t, ok) + assert.Equal(t, md, map[string]string{ + "key1": "value1", + }) + + ctx = ClearContextExcept(ctx, "non-existent-key") + md, ok = Get(ctx) + assert.False(t, ok) + assert.Equal(t, md, map[string]string(nil)) +} + +func TestGetValue(t *testing.T) { + ctx := context.Background() + + ctx = AddPairs(ctx, map[string]string{ + "key1": "value1", "key2": "value2", + }) + + val, ok := GetValue(ctx, "non-existent-key") + assert.False(t, ok) + assert.Equal(t, val, "") + + val, ok = GetValue(ctx, "key1") + assert.That(t, ok) + assert.Equal(t, val, "value1") + + val, ok = GetValue(ctx, "key2") + assert.That(t, ok) + assert.Equal(t, val, "value2") + + val, ok = GetValue(ctx, "Key1") // case-sensitivity test + assert.False(t, ok) + assert.Equal(t, val, "") +} + +func TestNewOutgoingContext(t *testing.T) { + originalCtx := context.Background() + + originalCtx = context.WithValue(originalCtx, outgoingMetadataKey{}, + map[string]string{"existing-key": "existing-value"}) + + newCtx := NewOutgoingContext(originalCtx, map[string]string{ + "key1": "value1", + "key2": "value2", + }) + + originalMd, ok := GetFromOutgoingContext(originalCtx) + assert.That(t, ok) + assert.Equal(t, originalMd, map[string]string{ + "existing-key": "existing-value", + }) + + newMd, ok := GetFromOutgoingContext(newCtx) + assert.That(t, ok) + assert.Equal(t, newMd, map[string]string{ + "key1": "value1", + "key2": "value2", + }) +}