diff --git a/go.mod b/go.mod index db368c3..3845158 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.0 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect diff --git a/go.sum b/go.sum index 85cd547..edbd508 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/aws/aws-sdk-go v1.43.24 h1:7c2PniJ0wpmWsIA6OtYBw6wS7DF0IjbhvPq+0ZQYNX github.com/aws/aws-sdk-go v1.43.24/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= diff --git a/publisher/sns/mock_test.go b/publisher/sns/mock_test.go index bc96cea..61d3904 100644 --- a/publisher/sns/mock_test.go +++ b/publisher/sns/mock_test.go @@ -15,3 +15,10 @@ func (p *snsPublisherMock) PublishWithContext(ctx context.Context, input *sns.Pu p.queue <- input.Message return &sns.PublishOutput{}, nil } + +func (p *snsPublisherMock) PublishBatchWithContext(ctx context.Context, input *sns.PublishBatchInput, o ...request.Option) (*sns.PublishBatchOutput, error) { + for _, entry := range input.PublishBatchRequestEntries { + p.queue <- entry.Message + } + return &sns.PublishBatchOutput{}, nil +} diff --git a/publisher/sns/sns.go b/publisher/sns/sns.go index a089bed..e215fd6 100644 --- a/publisher/sns/sns.go +++ b/publisher/sns/sns.go @@ -9,12 +9,14 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sns" + "github.com/google/uuid" ) // sender is the interface to sns.SNS. Its sole purpose is to make // Publisher.service and interface that we can mock for testing. type sender interface { PublishWithContext(ctx context.Context, input *sns.PublishInput, o ...request.Option) (*sns.PublishOutput, error) + PublishBatchWithContext(ctx context.Context, input *sns.PublishBatchInput, o ...request.Option) (*sns.PublishBatchOutput, error) } // Config holds the info required to work with AWS SNS @@ -58,6 +60,74 @@ func (p *Publisher) Publish(ctx context.Context, msg interface{}) error { return err } +// PublishBatch allows SNS Publisher to implement the publisher.Publisher interface +// and publish messages in a single batch to an AWS SNS backend. Since AWS SNS batch +// publish can only handle a maximum payload of 10 messages at a time, the messages +// supplied will be published in batches of 10. For this reason, message sets are best +// kept under 100 messages so that all messages can be published in 10 tries. In case +// of failure when parsing or publishing any of the messages, this function will stop +// further publishing and return an error +func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error { + var ( + defaultMessageGroupID = "default" + err error + ) + + isFifo := strings.Contains(strings.ToLower(p.cfg.TopicArn), "fifo") + + var ( + numPublishedMessages = 0 + start = 0 + end = 10 // 10 is the maximum batch size for SNS.PublishBatch + ) + if end > len(msgs) { + end = len(msgs) + } + for numPublishedMessages < len(msgs) { + var ( + requestEntries = make([]*sns.PublishBatchRequestEntry, 0) + ) + for idx := start; idx < end; idx++ { + msg := msgs[idx] + + b, err := json.Marshal(msg) + if err != nil { + return err + } + + entryId := uuid.New().String() + requestEntry := &sns.PublishBatchRequestEntry{ + Id: aws.String(entryId), + Message: aws.String(string(b)), + } + + if isFifo { + requestEntry.MessageGroupId = &defaultMessageGroupID + } + + requestEntries = append(requestEntries, requestEntry) + } + + input := &sns.PublishBatchInput{ + PublishBatchRequestEntries: requestEntries, + TopicArn: &p.cfg.TopicArn, + } + _, err = p.sns.PublishBatchWithContext(ctx, input) + if err != nil { + return err + } + + numPublishedMessages += len(requestEntries) + start = end + end += 10 + if end > len(msgs) { + end = len(msgs) + } + } + + return err +} + func defaultPublisherConfig(cfg *Config) { if cfg.AWSSession == nil { cfg.AWSSession = session.Must(session.NewSession()) diff --git a/publisher/sns/sns_test.go b/publisher/sns/sns_test.go index d519461..b0feddf 100644 --- a/publisher/sns/sns_test.go +++ b/publisher/sns/sns_test.go @@ -26,6 +26,31 @@ func TestPublisher(t *testing.T) { require.Equal(t, *publishedMessage, `{"msg":"message"}`) } +func TestPublisherBatch(t *testing.T) { + inputs := []interface{}{ + jsonString(`{"key":"val1"}`), + jsonString(`{"key":"val2"}`), + } + + queue := make(chan *string, len(inputs)) + defer close(queue) + + pubs := New(Config{}) + pubs.sns = &snsPublisherMock{queue: queue} + + require.NoError(t, pubs.PublishBatch(context.TODO(), inputs)) + + idx := 0 + for v := range queue { + publishedMessage := *v + require.Equal(t, jsonString(publishedMessage), inputs[idx]) + idx++ + if idx >= len(inputs) { + break + } + } +} + func TestPublisherDefaults(t *testing.T) { tt := []struct {