Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 75 additions & 20 deletions drpcmetadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ package drpcmetadata

import (
"context"
"strings"

"github.com/zeebo/errs"
)

// AddPairs attaches metadata onto a context and return the context.
// AddPairs attaches metadata onto an incoming context and returns the context.
func AddPairs(ctx context.Context, metadata map[string]string) context.Context {
// Get returns a copy of metadata
newMetadata, ok := Get(ctx)
Expand All @@ -19,7 +20,47 @@ func AddPairs(ctx context.Context, metadata map[string]string) context.Context {
for k, v := range metadata {
newMetadata[k] = v
}
return context.WithValue(ctx, metadataKey{}, newMetadata)
return context.WithValue(ctx, incomingMetadataKey{}, newMetadata)
}

// AddPairsToOutgoingContext attaches metadata onto an outgoing context and
// returns the context.
func AddPairsToOutgoingContext(ctx context.Context, metadata map[string]string) context.Context {
// Get existing metadata
existingMd, ok := ctx.Value(outgoingMetadataKey{}).(map[string]string)
if !ok {
return ctx
}
newMetadata := make(map[string]string)
for k, v := range existingMd {
newMetadata[k] = v
}
for k, v := range metadata {
newMetadata[k] = v
}
return context.WithValue(ctx, outgoingMetadataKey{}, newMetadata)
}

// NewIncomingContext attaches new metadata onto a context and returns the
// context.
func NewIncomingContext(ctx context.Context,
metadata map[string]string) context.Context {
newMetadata := make(map[string]string)
for k, v := range metadata {
newMetadata[k] = v
}
return context.WithValue(ctx, incomingMetadataKey{}, newMetadata)
}

// NewOutgoingContext attaches new metadata onto a context and returns the
// context.
func NewOutgoingContext(ctx context.Context,
metadata map[string]string) context.Context {
newMetadata := make(map[string]string)
for k, v := range metadata {
newMetadata[k] = v
}
return context.WithValue(ctx, outgoingMetadataKey{}, newMetadata)
}

// Encode generates byte form of the metadata and appends it onto the passed in buffer.
Expand Down Expand Up @@ -53,44 +94,42 @@ func Decode(buf []byte) (map[string]string, error) {
return out, nil
}

type metadataKey struct{}
type incomingMetadataKey struct{}
type outgoingMetadataKey struct{}

// ClearContext removes all metadata from the context and returns a new context
// with no metadata attached.
// ClearContext removes all metadata from the incoming context and returns a new
// context with no metadata attached.
func ClearContext(ctx context.Context) context.Context {
return context.WithValue(ctx, metadataKey{}, nil)
return context.WithValue(ctx, incomingMetadataKey{}, nil)
}

// ClearContextExcept removes all metadata from the context except for the
// specified key. If the specified key doesn't exist in the metadata, it clears
// ClearContextExcept removes all metadata from the incoming context except for
// the specified key. If the specified key doesn't exist in the metadata, it clears
// all metadata. Returns a new context with only the specified key-value pair
// preserved.
func ClearContextExcept(ctx context.Context, key string) context.Context {
md, ok := Get(ctx)
if !ok {
return ClearContext(ctx)
}
value, ok := md[key]
value, ok := GetValue(ctx, key)
if !ok {
return ClearContext(ctx)
}
return context.WithValue(ctx, metadataKey{}, map[string]string{key: value})
return context.WithValue(ctx, incomingMetadataKey{},
map[string]string{strings.ToLower(key): value})
}

// Add associates a key/value pair on the context.
// Add associates a key/value pair on the incoming context.
func Add(ctx context.Context, key, value string) context.Context {
// Get returns a copy of metadata
metadata, ok := Get(ctx)
if !ok {
metadata = make(map[string]string)
}
metadata[key] = value
return context.WithValue(ctx, metadataKey{}, metadata)
return context.WithValue(ctx, incomingMetadataKey{}, metadata)
}

// Get returns all key/value pairs on the given context.
// Get returns all key/value pairs on the given incoming context.
func Get(ctx context.Context) (map[string]string, bool) {
metadata, ok := ctx.Value(metadataKey{}).(map[string]string)
metadata, ok := ctx.Value(incomingMetadataKey{}).(map[string]string)
if !ok {
return nil, false
}
Expand All @@ -102,9 +141,25 @@ func Get(ctx context.Context) (map[string]string, bool) {
return copy, true
}

// GetValue retrieves a specific value by key from the context's metadata.
// GetFromOutgoingContext returns all key/value pairs on the given incoming
// context.
func GetFromOutgoingContext(ctx context.Context) (map[string]string, bool) {
metadata, ok := ctx.Value(outgoingMetadataKey{}).(map[string]string)
if !ok {
return nil, false
}
// Return a copy to prevent mutation of the original map
copy := make(map[string]string)
for k, v := range metadata {
copy[k] = v
}
return copy, true
}

// GetValue retrieves a specific value by key from the incoming context's
// metadata.
func GetValue(ctx context.Context, key string) (string, bool) {
metadata, ok := Get(ctx)
metadata, ok := ctx.Value(incomingMetadataKey{}).(map[string]string)
if !ok {
return "", false
}
Expand Down
116 changes: 116 additions & 0 deletions drpcmetadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ func TestAddGet(t *testing.T) {
}
}

func TestGetFromOutgoingContext(t *testing.T) {
ctx := context.Background()

ctx = context.WithValue(ctx, outgoingMetadataKey{}, map[string]string{
"foo": "bar",
"ak": "av",
"bk": "bv",
})

metadata, ok := GetFromOutgoingContext(ctx)
assert.That(t, ok)
assert.Equal(t, metadata, map[string]string{
"foo": "bar",
"ak": "av",
"bk": "bv",
})
}

func TestEncode(t *testing.T) {
t.Run("Empty Metadata", func(t *testing.T) {
var metadata map[string]string
Expand Down Expand Up @@ -136,3 +154,101 @@ func TestAddPairsImmutability(t *testing.T) {
assert.Equal(t, newMd["key1"], "val1")
assert.Equal(t, newMd["key2"], "val2")
}

func TestNewIncomingContext(t *testing.T) {
originalCtx := context.Background()
originalCtx = Add(originalCtx, "existing", "value")

newCtx := NewIncomingContext(originalCtx, map[string]string{
"key1": "value1",
"key2": "value2",
})
originalMd, ok := Get(originalCtx)
assert.That(t, ok)
assert.Equal(t, originalMd, map[string]string{"existing": "value"})

newMd, ok := Get(newCtx)
assert.That(t, ok)
assert.Equal(t, newMd, map[string]string{
"key1": "value1",
"key2": "value2",
})
}

func TestClearContext(t *testing.T) {
ctx := context.Background()
ctx = Add(ctx, "existing", "value")

ctx = ClearContext(ctx)
newMd, ok := Get(ctx)
assert.False(t, ok)
assert.Equal(t, newMd, map[string]string(nil))
}

func TestClearContextExcept(t *testing.T) {
ctx := context.Background()
ctx = AddPairs(ctx, map[string]string{
"key1": "value1", "key2": "value2",
})

ctx = ClearContextExcept(ctx, "key1")
md, ok := Get(ctx)
assert.That(t, ok)
assert.Equal(t, md, map[string]string{
"key1": "value1",
})

ctx = ClearContextExcept(ctx, "non-existent-key")
md, ok = Get(ctx)
assert.False(t, ok)
assert.Equal(t, md, map[string]string(nil))
}

func TestGetValue(t *testing.T) {
ctx := context.Background()

ctx = AddPairs(ctx, map[string]string{
"key1": "value1", "key2": "value2",
})

val, ok := GetValue(ctx, "non-existent-key")
assert.False(t, ok)
assert.Equal(t, val, "")

val, ok = GetValue(ctx, "key1")
assert.That(t, ok)
assert.Equal(t, val, "value1")

val, ok = GetValue(ctx, "key2")
assert.That(t, ok)
assert.Equal(t, val, "value2")

val, ok = GetValue(ctx, "Key1") // case-sensitivity test
assert.False(t, ok)
assert.Equal(t, val, "")
}

func TestNewOutgoingContext(t *testing.T) {
originalCtx := context.Background()

originalCtx = context.WithValue(originalCtx, outgoingMetadataKey{},
map[string]string{"existing-key": "existing-value"})

newCtx := NewOutgoingContext(originalCtx, map[string]string{
"key1": "value1",
"key2": "value2",
})

originalMd, ok := GetFromOutgoingContext(originalCtx)
assert.That(t, ok)
assert.Equal(t, originalMd, map[string]string{
"existing-key": "existing-value",
})

newMd, ok := GetFromOutgoingContext(newCtx)
assert.That(t, ok)
assert.Equal(t, newMd, map[string]string{
"key1": "value1",
"key2": "value2",
})
}